网站如何制作学校的做,快速免费做网站,网站进入,摄影网站的需求分析在使用pytorch模型训练完成之后#xff0c;我们现在使用的比较多的一种方法是将pytorch模型转成onnx格式的模型中间文件#xff0c;然后再根据使用的硬件来生成具体硬件使用的深度学习模型#xff0c;比如TensorRT。 在从pytorch模型转为onnx时#xff0c;我们可能会遇到部…在使用pytorch模型训练完成之后我们现在使用的比较多的一种方法是将pytorch模型转成onnx格式的模型中间文件然后再根据使用的硬件来生成具体硬件使用的深度学习模型比如TensorRT。 在从pytorch模型转为onnx时我们可能会遇到部分算子无法转换的问题本篇注意记录下解决方法。
在导出onnx时如果出现报错的算子可以先在下面的链接中查找onnx算子是否支持 https://github.com/onnx/onnx/blob/main/docs/Operators.md
pytorch中有onnx中也有的算子
导出时使用的onnx op 版本低导致
这个就好解决了把op库的版本提高就行但是有可能提高了版本以后又出现了原来支持的算子现在又不支持了这个再说
pytorch中没有注册某个onnx算子
如果是这种情况就按照下面的方式进行
from torch.onnx import register_custom_op_symbolic
# 创建一个asinh算子的symblic符号函数用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
# g: 就是graph计算图
# 也就是说在计算图中添加onnx算子
# 由于我们已经知道Asinh在onnx是有实现的所以我们只要在g.op调用这个op的名字就好了
# symblic的参数需要与Pytorch的asinh接口函数的参数对齐
# def asinh(input: Tensor, *, out: Optional[Tensor]None) - Tensor: ...
def asinh_symbolic(g, input, *, outNone):return g.op(Asinh, input)# 在这里将asinh_symbolic这个符号函数与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c命名空间下进行实现的# aten是a Tensor Library的缩写是一个实现张量运算的C库
register_custom_op_symbolic(aten::asinh, asinh_symbolic, 12)另外一个写法 这个是类似于torch/onnx/symbolic_opset*.py中的写法 通过torch._internal中的registration来注册这个算子让这个算子可以与底层C实现的aten::asinh绑定 一般如果这么写的话其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic functools.partial(registration.onnx_symbolic, opset9)_onnx_symbolic(aten::asinh)
def asinh_symbolic(g, input, *, outNone):return g.op(Asinh, input)pytorch中有onnx中无的算子
继承torch.autograd.Function实现自定义算子
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicOperatorExportTypes torch._C._onnx.OperatorExportTypesclass CustomOp(torch.autograd.Function):staticmethod def symbolic(g: torch.Graph, x: torch.Value) - torch.Value:return g.op(custom_domain::customOp2, x)staticmethoddef forward(ctx, x: torch.Tensor) - torch.Tensor:ctx.save_for_backward(x)x x.clamp(min0)return x / (1 torch.exp(-x))customOp CustomOp.apply然后再自己实现custom_domain::customOp2这个算子如果用TensorRT就需要自己实现一个插件。