当前位置: 首页 > news >正文

Pytorch图像模型转ONNX后出现色偏问题

本篇记录一次从Pytorch图像处理模型转换成ONNX模型之后,在推理过程中出现了明显色偏问题的解决过程。

问题描述:原始pytorch模型推理正常,通过torch.onnx.export()函数转换成onnx之后,推理时出现了比较明显的颜色偏差。

原始模型转换程序如下:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')pth_path = 'model/my_model.pth'
onnx_path = 'model/my_model.onnx'# 模型定义
model = MyModelStruct()
# 加载模型到设备
model.to(device)# 加载checkpoint
checkpoint = torch.load(pth_path, map_location=device)
# 将checkpoint加载到模型
model.load_state_dict(checkpoint) # 将模型设置为推理模式
model.eval()# 定义模型输入输出
input_names = ['input', 'a', 'b']
output_names = ['output']# 定义输入数据格式,随机数初始化
input = torch.rand(1,3, 512, 512)
a = torch.rand(1)
b = torch.rand(1)# 将数据加载到设备
input = input.to(device)
a = a.to(device)
b = b.to(devici)# 开始转换
torch.onnx.export(model, (input, a, b), onnx_path, input_names=input_names, output_names=output_names, verbose=True)print('Done.')

问题解决过程:

1. 由于不是自己的模型,因此,这个pytorch模型拿到手后,先自己写了推理程序,在自己的PC上跑了一下。我自己的PC机只有CPU,在CPU上运行的结果跟onnx上一致,也存在色偏。但将同样的程序放到GPU服务器上运行,结果确实正常的。因此得出结论,应该是模型中间的某些算子,在CPU和GPU上的处理存在误差。

2. 将pytorch模型使用上面的转换程序转换成onnx之后,通过onnxruntime进行推理(这部分本篇先略过,后面再专门开一篇写onnxruntime推理),同样偏色,复现了最初的问题。

3. 通过研究torch.onnx.export()函数,发现跟算子处理关系最大的参数是opset_version,我的版本中,默认使用的opset_version为14,尝试换到16,转换出来的模型推理结果竟然正常了!

相应转换代码替换为:

torch.onnx.export(model, (input, a, b), onnx_path, input_names=input_names, output_names=output_names, verbose=True, opset_version=16)

torch.onnx.export()接收的部分关键参数解释如下:

_export(model,args,f,export_params,verbose,training,input_names,output_names,operator_export_type=operator_export_type,opset_version=opset_version,do_constant_folding=do_constant_folding,dynamic_axes=dynamic_axes,keep_initializers_as_inputs=keep_initializers_as_inputs,custom_opsets=custom_opsets,export_modules_as_functions=export_modules_as_functions,)

其中,

  • model:需要被转换成onnx的模型。
  • args:模型输入参数,一般我们在这里指定模型输入数据的尺寸,如果模型有多个参数,该参数也可以是一个元组,如本例中模型输入三个参数。
  • f:要导出的onnx模型的路径,包括onnx文件名。
  • export_params:(bool, default True),为True时所有的模型参数都会被导出;为False时,则会导出一个未被训练的模型。
  • training:(enum, default TrainingMode.EVAL),有三种模式,分别为:TrainingMode.EVAL,TrainingMOde.PRESERVE,TrainingMode.TRAINING,一般使用EVAL模式即可。
  • input_names:模型图中输入节点的名称,字符串。
  • output_names:模型图中输出节点的名称,字符串。
  • operator_export_type:算子导出类型,包括:OperatorExportTypes.ONNX,OperatorExportTypes.ONNX_FALLTHROUGH,OperatorExportTypes.ONNX_ATEN,OperatorExportTypes.ONNX_ATEN_FALLBACK,我们常用的模型算子一般都有onnx支持,因此默认选第一种。
  • opset_version:这是我们本篇中问题解决的核心参数,默认opset版本为14,可用范围是7~16,通过手动设置,将其设为最高版本,问题初步解决。
  • dynamic_axes:允许在导出的onnx模型中创建变化的维度,是一个字典形式,默认为空。

4. 到第三步其实还没完,因为,虽然模型输出没有偏色了,但是拿到实际场景中去运行,发现它的运行速度变慢了很多。于是又尝试了使用opset_version 11:

torch.onnx.export(model, (input, a, b), onnx_path, input_names=input_names, output_names=output_names, verbose=True, opset_version=11)

这次,也没有偏色,而且推理速度回归正常。

目前暂时使用opset_version 11,至于为什么版本16会使运行速度变慢,还需要更深入地去了解不同版本地差异。留待后续吧。

相关文章:

  • Visual Studio 2010 软件安装教程(附下载链接)——计算机二级专用编程软件
  • 虽然许多人表示对Windows 11的透明任务栏不满,但有时效果还是挺好的
  • uni-app小程序使用vant
  • 虚拟机VirtualBox添加磁盘
  • React——简便获取经纬度信息
  • 家庭私人影院 - Windows搭建Emby媒体库服务器并远程访问 「无公网IP」
  • 不必安装,快速设计数据库表结构
  • 【gpt redis】原理篇
  • 蓝桥杯官网填空题(含2天数)
  • Java程序设计2023-第三次上机练习
  • pytorch复现_conv2d
  • 读程序员的制胜技笔记04_有用的反模式(下)
  • 【重磅】Cookies、headers、Session规律总结,搞定卡点
  • 2023.11.4 Idea 配置国内 Maven 源
  • 高级深入--day46
  • android高仿小视频、应用锁、3种存储库、QQ小红点动画、仿支付宝图表等源码...
  • Computed property XXX was assigned to but it has no setter
  • CSS盒模型深入
  • Docker: 容器互访的三种方式
  • ECMAScript 6 学习之路 ( 四 ) String 字符串扩展
  • Hibernate【inverse和cascade属性】知识要点
  • HTTP--网络协议分层,http历史(二)
  • JavaScript对象详解
  • js正则,这点儿就够用了
  • Netty 框架总结「ChannelHandler 及 EventLoop」
  • Phpstorm怎样批量删除空行?
  • Redux系列x:源码分析
  • vue--为什么data属性必须是一个函数
  • 工作中总结前端开发流程--vue项目
  • 前端
  • 如何设计一个微型分布式架构?
  • 深入浅出Node.js
  • 使用API自动生成工具优化前端工作流
  • 学习笔记DL002:AI、机器学习、表示学习、深度学习,第一次大衰退
  • 用Node EJS写一个爬虫脚本每天定时给心爱的她发一封暖心邮件
  • 在weex里面使用chart图表
  • Android开发者必备:推荐一款助力开发的开源APP
  • 智能情侣枕Pillow Talk,倾听彼此的心跳
  • ​linux启动进程的方式
  • #Linux(make工具和makefile文件以及makefile语法)
  • #QT项目实战(天气预报)
  • $.extend({},旧的,新的);合并对象,后面的覆盖前面的
  • ()、[]、{}、(())、[[]]命令替换
  • (10)ATF MMU转换表
  • (10)STL算法之搜索(二) 二分查找
  • (12)Hive调优——count distinct去重优化
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (libusb) usb口自动刷新
  • (ZT) 理解系统底层的概念是多么重要(by趋势科技邹飞)
  • (待修改)PyG安装步骤
  • (附源码)ssm经济信息门户网站 毕业设计 141634
  • (数位dp) 算法竞赛入门到进阶 书本题集
  • (一)u-boot-nand.bin的下载
  • (一)UDP基本编程步骤
  • (原創) 是否该学PetShop将Model和BLL分开? (.NET) (N-Tier) (PetShop) (OO)