当前位置: 首页 > news >正文

nnunetv2系列:torch转onnx

nnunetv2系列:torch转onnx

首先感谢 https://blog.csdn.net/chen_niansan/article/details/141940273 作者提供的代码。这里做了一些简单的调整和优化,修复了原代码存在的输出文件被覆盖的问题。

代码示例


import json
from os.path import isdir, join
from pathlib import Path
from typing import Tuple, Union

import numpy as np
from onnx import checker, load
from onnxruntime import InferenceSession
from torch import rand
from torch.onnx import export as torch_onnx_export

from batchgenerators.utilities.file_and_folder_operations import load_json
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.file_path_utilities import get_output_folder


def export_dataset_json(
    json_path,
    c,
    fold,
    batch_size,
    use_dynamic_axes,
    config,
    dataset_name,
    dataset_json,
    foreground_intensity_properties,
):
    """Write a sidecar JSON file describing one exported ONNX model.

    Args:
        json_path: Destination path for the JSON file.
        c: nnU-Net configuration name (e.g. "2d").
        fold: Fold index the exported weights belong to.
        batch_size: Fixed batch size used at export time; recorded as the
            string "dynamic" when ``use_dynamic_axes`` is true.
        use_dynamic_axes: Whether the ONNX model has a dynamic batch axis.
        config: nnU-Net configuration manager; this function reads its
            ``patch_size``, ``spacing`` and ``normalization_schemes``.
        dataset_name: Resolved nnU-Net dataset name.
        dataset_json: Parsed ``dataset.json`` of the trained model.
        foreground_intensity_properties: Per-channel intensity statistics
            from ``plans.json`` (useful when normalization is not Z-Score).
    """
    config_dict = {
        "configuration": c,
        "fold": fold,
        "model_parameters": {
            "batch_size": batch_size if not use_dynamic_axes else "dynamic",
            "patch_size": config.patch_size,
            "spacing": config.spacing,
            "normalization_schemes": config.normalization_schemes,
            # These are mostly interesting for certification uses,
            # but they are also useful for debugging.
            # "UNet_class_name": config.UNet_class_name,
            # "UNet_base_num_features": config.UNet_base_num_features,
            # "unet_max_num_features": config.unet_max_num_features,
            # "conv_kernel_sizes": config.conv_kernel_sizes,
            # "pool_op_kernel_sizes": config.pool_op_kernel_sizes,
            # "num_pool_per_axis": config.num_pool_per_axis,
        },
        "dataset_parameters": {
            "dataset_name": dataset_name,
            "num_channels": len(dataset_json["channel_names"]),
            "channels": {
                k: {
                    "name": v,
                    # For when normalization is not Z-Score
                    "foreground_properties": foreground_intensity_properties[k],
                }
                for k, v in dataset_json["channel_names"].items()
            },
            "num_classes": len(dataset_json["labels"]),
            "class_names": {v: k for k, v in dataset_json["labels"].items()},
        },
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=4)


def export_onnx_model(
    dataset_name_or_id: Union[int, str],
    output_dir: Path,
    configurations: Tuple[str] = ("2d",),
    batch_size: int = 0,
    trainer: str = "nnUNetTrainer",
    plans_identifier: str = "nnUNetPlans",
    # folds: Tuple[Union[int, str], ...] = (0, 1, 2, 3, 4),
    folds: Tuple[Union[int, str], ...] = (0,),
    strict: bool = True,
    save_checkpoints: Tuple[str, ...] = ("checkpoint_final.pth",),
    output_names: tuple[str, ...] = None,
    verbose: bool = False,
) -> None:
    """Export trained nnU-Net v2 checkpoints to ONNX and verify them.

    For every configuration, loads the trained model folder, exports each
    requested checkpoint/fold to ONNX, validates the file with the onnx
    checker, and compares onnxruntime output against the torch output on a
    random input. A ``config_<fold>.json`` sidecar file is written next to
    each exported model.

    Args:
        dataset_name_or_id: nnU-Net dataset name or numeric id.
        output_dir: Directory that receives the .onnx and .json files.
        configurations: Configurations to export (e.g. ("2d", "3d_fullres")).
        batch_size: Fixed batch size for the exported model; 0 selects a
            dynamic batch axis. Must be non-negative.
        trainer: Trainer class name used for training.
        plans_identifier: Plans identifier used for training.
        folds: Folds to export.
        strict: Raise (instead of skip) when a configuration is missing.
        save_checkpoints: Checkpoint file names to export.
        output_names: ONNX file names, paired with ``save_checkpoints``;
            derived from checkpoint and fold when None.
        verbose: Forwarded to ``torch.onnx.export``.

    Raises:
        ValueError: If ``batch_size`` is negative.
        RuntimeError: If ``strict`` and a configuration folder is missing.
    """
    if not output_names:
        # Materialize as a tuple: a generator would be exhausted after the
        # first configuration and silently export nothing afterwards.
        output_names = tuple(
            f"{checkpoint[:-4]}_fold_{fold_index}.onnx"
            for fold_index, checkpoint in zip(folds, save_checkpoints)
        )

    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")
    # batch_size == 0 means "export with a dynamic batch dimension".
    use_dynamic_axes = batch_size == 0

    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)

    for c in configurations:
        trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c)

        # Check existence before trying to read dataset.json / plans.json.
        if not isdir(trainer_output_dir):
            if strict:
                raise RuntimeError(
                    f"{dataset_name} is missing the trained model of configuration {c}"
                )
            print(f"Skipping configuration {c}, does not exist")
            continue

        dataset_json = load_json(join(trainer_output_dir, "dataset.json"))
        # While we load in this file indirectly, we need the plans file to
        # determine the foreground intensity properties.
        plans = load_json(join(trainer_output_dir, "plans.json"))
        foreground_intensity_properties = plans["foreground_intensity_properties_per_channel"]
        # The dummy input must have the dataset's channel count (the
        # original code hardcoded 1 or 3 channels, which breaks export
        # for any other number of input modalities).
        num_input_channels = len(dataset_json["channel_names"])

        predictor = nnUNetPredictor(
            # perform_everything_on_gpu=False,
            # device=torch.device("cpu"),
        )

        for checkpoint_name, output_name in zip(save_checkpoints, output_names):
            predictor.initialize_from_trained_model_folder(
                model_training_output_dir=trainer_output_dir,
                use_folds=folds,
                checkpoint_name=checkpoint_name,
                # disable_compilation=True,
            )
            list_of_parameters = predictor.list_of_parameters
            network = predictor.network
            config = predictor.configuration_manager

            for fold, params in zip(folds, list_of_parameters):
                network.load_state_dict(params)
                network.eval()

                # NOTE(review): all folds of one checkpoint share a single
                # output_name, so exporting several folds per checkpoint
                # overwrites the file; use a per-fold subdirectory
                # (e.g. output_dir / c / f"fold_{fold}") if that matters.
                # curr_output_dir = output_dir / c / f"fold_{fold}"
                curr_output_dir = Path(output_dir)
                curr_output_dir.mkdir(parents=True, exist_ok=True)

                if use_dynamic_axes:
                    rand_input = rand((1, num_input_channels, *config.patch_size))
                    torch_output = network(rand_input)
                    torch_onnx_export(
                        network,
                        rand_input,
                        curr_output_dir / output_name,
                        export_params=True,
                        verbose=verbose,
                        input_names=["input"],
                        output_names=["output"],
                        dynamic_axes={
                            "input": {0: "batch_size"},
                            "output": {0: "batch_size"},
                        },
                    )
                else:
                    rand_input = rand((batch_size, num_input_channels, *config.patch_size))
                    torch_output = network(rand_input)
                    torch_onnx_export(
                        network,
                        rand_input,
                        curr_output_dir / output_name,
                        export_params=True,
                        verbose=verbose,
                        input_names=["input"],
                        output_names=["output"],
                    )

                # Structural validation of the produced file.
                onnx_model = load(curr_output_dir / output_name)
                checker.check_model(onnx_model)

                # Numerical validation: run the same random input through
                # onnxruntime and compare with the torch output.
                ort_session = InferenceSession(
                    curr_output_dir / output_name,
                    providers=["CPUExecutionProvider"],
                )
                ort_inputs = {ort_session.get_inputs()[0].name: rand_input.numpy()}
                ort_outs = ort_session.run(None, ort_inputs)

                try:
                    np.testing.assert_allclose(
                        torch_output.detach().cpu().numpy(),
                        ort_outs[0],
                        rtol=1e-03,
                        atol=1e-05,
                        verbose=True,
                    )
                except AssertionError as e:
                    # Best-effort export: warn but keep going.
                    print("WARN: Differences found between torch and onnx:\n")
                    print(e)
                    print(
                        "\nExport will continue, but please verify that your "
                        "pipeline matches the original."
                    )

                print(f"Exported {curr_output_dir / output_name}")

                export_dataset_json(
                    f"{curr_output_dir}/config_{fold}.json",
                    c,
                    fold,
                    batch_size,
                    use_dynamic_axes,
                    config,
                    dataset_name,
                    dataset_json,
                    foreground_intensity_properties,
                )


if __name__ == "__main__":
    export_onnx_model(
        dataset_name_or_id="500",
        output_dir="onnx_model",
        configurations=("2d",),
        batch_size=4,
        folds=(1,),
        save_checkpoints=("checkpoint_final.pth",),
        output_names=None,  # tuple of ONNX file names
    )

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • AI学习指南深度学习篇-带动量的随机梯度下降法Python实践
  • 技术美术一百问(01)
  • 基于CNN-BiGUR的恶意域名检测方法
  • IDC基础学习笔记
  • Pycharm Remote Development 报错解决
  • HTTP 协议和 APACHE 服务
  • TikTok运营需要的独立IP如何获取?
  • Redis 入门 - 五大基础类型及其指令学习
  • 代码随想录冲冲冲 Day41 动态规划Part9
  • Mysql | 知识 | 事务隔离级别
  • Kylin Server V10 下 MySQL 8 binlog 管理
  • Spark底层逻辑
  • 高教社杯数模竞赛特辑论文篇-2013年B题:碎纸复原模型与算法
  • 分享一个基于微信小程序的医院挂号就诊一体化平台uniapp医院辅助挂号应用小程序设计(源码、调试、LW、开题、PPT)
  • WORD批量转换器MultiDoc Converter
  • 【162天】黑马程序员27天视频学习笔记【Day02-上】
  • 5、React组件事件详解
  • CentOS 7 修改主机名
  • Java方法详解
  • Just for fun——迅速写完快速排序
  • MaxCompute访问TableStore(OTS) 数据
  • python学习笔记-类对象的信息
  • Spring框架之我见(三)——IOC、AOP
  • vue 个人积累(使用工具,组件)
  • 程序员该如何有效的找工作?
  • 工作手记之html2canvas使用概述
  • 聚类分析——Kmeans
  • 浏览器缓存机制分析
  • 前端面试总结(at, md)
  • 通过获取异步加载JS文件进度实现一个canvas环形loading图
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • 教程:使用iPhone相机和openCV来完成3D重建(第一部分) ...
  • Spring Boot 分片上传文件
  • 经纬恒润二面 / 三七互娱一面 / 元象二面
  • 什么是bug?bug的源头在哪里?
  • 中南建设2022年半年报“韧”字当头,经营性现金流持续为正
  • 字节一面
  • ![CDATA[ ]] 是什么东东
  • #设计模式#4.6 Flyweight(享元) 对象结构型模式
  • $Django python中使用redis, django中使用(封装了),redis开启事务(管道)
  • (01)ORB-SLAM2源码无死角解析-(66) BA优化(g2o)→闭环线程:Optimizer::GlobalBundleAdjustemnt→全局优化
  • (3)(3.2) MAVLink2数据包签名(安全)
  • (6) 深入探索Python-Pandas库的核心数据结构:DataFrame全面解析
  • (c语言版)滑动窗口 给定一个字符串,只包含字母和数字,按要求找出字符串中的最长(连续)子串的长度
  • (Git) gitignore基础使用
  • (二)linux使用docker容器运行mysql
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (十)Flink Table API 和 SQL 基本概念
  • (五)c52学习之旅-静态数码管
  • (五)大数据实战——使用模板虚拟机实现hadoop集群虚拟机克隆及网络相关配置
  • (转) Android中ViewStub组件使用
  • (转)Android学习笔记 --- android任务栈和启动模式
  • .NET BackgroundWorker
  • .NET C# 使用 SetWindowsHookEx 监听鼠标或键盘消息以及此方法的坑
  • .NET MVC第三章、三种传值方式