当前位置: 首页 > news >正文

pytorch导出rot90算子至onnx

如何导出rot90算子至onnx

    • 1 背景描述
    • 2 等价替换
      • 2.1 rot90替换(NCHW)
      • 2.2 rot180替换(NCHW)
      • 2.3 rot270替换(NCHW)
    • 3 rot导出ONNX

1 背景描述

在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1opset_version为17):

import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(2, 3))return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90算子,onnx官方github链接:
https://github.com/onnx/onnx


2 等价替换

导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。

2.1 rot90替换(NCHW)

废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()来对比两个Tensor是否一致,结果一致,不信自己试试。

import torchdef self_rot90_counterclockwise(x: torch.Tensor):x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=1, dims=[2, 3])y1 = self_rot90_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

2.2 rot180替换(NCHW)

rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:

import torchdef self_rot180_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2, 3])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=2, dims=[2, 3])y1 = self_rot180_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

2.3 rot270替换(NCHW)

rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:

import torchdef self_rot270_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=3, dims=[2, 3])y1 = self_rot270_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

3 rot导出ONNX

这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:

import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):# x = torch.rot90(x, k=1, dims=(2, 3))x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()

使用netron打开生成的rot90_counterclockwise.onnx文件,如下所示:

在这里插入图片描述

相关文章:

  • 24. 深度学习进阶 - 矩阵运算的维度和激活函数
  • 嵌入式硬件电路·电平
  • Linux中vi常用命令-批量替换
  • 智能医疗越发周到!新的机器人系统评估中风后的活动能力
  • 从零开始学习管道:管道程序的优化和文件描述符继承问题
  • gitee推荐-1Panel
  • 搜索百度可以直接生成代码拉
  • 【广州华锐互动】节约用水VR互动教育:身临其境体验水资源的珍贵!
  • ubuntu/vscode下的c/c++开发之-CMake语法与练习
  • Git多库多账号本地SSH连接配置方法
  • gitea仓库镜像同步至gitlab
  • 阿里云跨账号建立局域网
  • 深入理解RC4加密算法
  • 2023亚太杯数学建模A题思路分析 - 采果机器人的图像识别技术
  • 线程基本方法
  • 【vuex入门系列02】mutation接收单个参数和多个参数
  • Android系统模拟器绘制实现概述
  • AzureCon上微软宣布了哪些容器相关的重磅消息
  • express.js的介绍及使用
  • HTML-表单
  • javascript从右向左截取指定位数字符的3种方法
  • leetcode-27. Remove Element
  • magento 货币换算
  • Map集合、散列表、红黑树介绍
  • Redis字符串类型内部编码剖析
  • sublime配置文件
  • Vultr 教程目录
  • 从0到1:PostCSS 插件开发最佳实践
  • 大型网站性能监测、分析与优化常见问题QA
  • 多线程 start 和 run 方法到底有什么区别?
  • 缓存与缓冲
  • 前端之Sass/Scss实战笔记
  • 如何用vue打造一个移动端音乐播放器
  • 如何正确配置 Ubuntu 14.04 服务器?
  • 视频flv转mp4最快的几种方法(就是不用格式工厂)
  • 通过获取异步加载JS文件进度实现一个canvas环形loading图
  • 智能合约Solidity教程-事件和日志(一)
  • 测评:对于写作的人来说,Markdown是你最好的朋友 ...
  • #我与Java虚拟机的故事#连载08:书读百遍其义自见
  • $ is not function   和JQUERY 命名 冲突的解说 Jquer问题 (
  • (C语言)二分查找 超详细
  • (pojstep1.3.1)1017(构造法模拟)
  • (超简单)构建高可用网络应用:使用Nginx进行负载均衡与健康检查
  • (附源码)ssm教师工作量核算统计系统 毕业设计 162307
  • (深度全面解析)ChatGPT的重大更新给创业者带来了哪些红利机会
  • (太强大了) - Linux 性能监控、测试、优化工具
  • .360、.halo勒索病毒的最新威胁:如何恢复您的数据?
  • .desktop 桌面快捷_Linux桌面环境那么多,这几款优秀的任你选
  • .Net Core缓存组件(MemoryCache)源码解析
  • .NET DevOps 接入指南 | 1. GitLab 安装
  • .netcore 6.0/7.0项目迁移至.netcore 8.0 注意事项
  • /dev下添加设备节点的方法步骤(通过device_create)
  • @ 代码随想录算法训练营第8周(C语言)|Day53(动态规划)
  • @for /l %i in (1,1,10) do md %i 批处理自动建立目录
  • @requestBody写与不写的情况