当前位置: 首页 > news >正文

pytorch保存和加载模型权重以及CUDA在pytorch中的使用

1.保存加载自定义模型

1.1 保存加载整个模型

'''模型保存'''
torch.save(model, 'model.pkl')
'''模型加载'''
model = torch.load('model.pkl', map_location='cuda:0')

这种方式直接保存加载整个网络结构,比较死板,不能调整网络结构。

1.2 保存加载模型参数

'''模型参数保存'''
torch.save(model.state_dict(), 'model_param.pkl')
'''模型参数加载'''
device = torch.device('cpu') # device = torch.device('cuda')
# 定义网络
model = RNN() # 举例,一个RNN类,定义了RNN模型的结构
# 加载模型参数到模型结构
model.load_state_dict('model_param.pkl', map_location=device)

这种方式需要自己先定义网络模型的结构才能加载模型的参数,并且定义的网络模型的参数名称和结构要与加载的模型一致(可以是部分网络,比如只使用神经网络的前几层),相对灵活,便于对网络进行修改。

2.加载预训练模型

由于加载保存整个网络模型比较死板,所以一般都只保存或者加载预训练模型的参数.

2.1 预训练模型网络结构 == 自定义模型网络结构

结构相同就不需要修改,可以直接套用

path="预训练模型地址"
model = CJK_MODEL()
checkpoint = torch.load(path, map_location="cuda:0")
# 加载参数
model = model.load_state_dict(torch.load(path))

2.2 预训练模型网络结构与自定义模型网络结构不一致

  1. 首先打印出两个网络模型的各层网络名称
	'''输出自定义模型的各层网络结构名称'''
	model_dict = model.state_dict()
	print(model_dict.keys())
	'''输出自定义模型的各层网络结构名称'''
	checkpoint = torch.load('./model_param.pkl')
	for k, v in checkpoint.items():
		print("keys:",k)
  1. 对比两者网络结构参数,如果差距太大就没有借用的必要了
  • 如果许多参数层的名称完全一致:
	model.load_state_dict(checkpoint, strict=True)
	'''
	load_state_dict 函数添加参数 strict=True,
	它直接忽略那些没有的dict,有相同的就复制,没有就直接放弃赋值. 
	他要求预训练模型的关键字必须确切地严格地和
	自定义网络的state_dict()函数返回的关键字相匹配才能赋值。
	'''
  • 如果许多参数层的名称大部分一致:
    比如自定义网络模型中参数层名称为backbone.stage0.rbr_dense.conv.weight,
    预训练模型中参数层名称为stage0.rbr_dense.conv.weight,可以看到二者大部分是一致的.这种情况下,可以把 预训练模型的stage0.rbr_dense.conv.weight读入网络的backbone.stage0.rbr_dense.conv.weight 中。

3.cuda

把在cpu上运算的张量转变为在GPU上运行的cuda张量可以显著提升训练速度

  • 在加载模型时转变为cuda张量
# 选择设备
device = ['cuda' if torch.cuda.is_available() else 'cpu']  
# 法一:加载整个模型
model = torch.load('model.pkl', map_location='cuda:0')
# 法二:只加载模型参数
model = model.load_state_dict(torch.load(path), map_location='cuda:0')
  • 单独转变为cuda张量
# 选择设备
device = ['cuda' if torch.cuda.is_available() else 'cpu']  
# 法一:加载整个模型
model = torch.load('model.pkl')
# 法二:只加载模型参数
model = model.load_state_dict(torch.load(path))
'''单独转变为cuda张量'''
model.to(device)

注意, 必须把所有的张量都转为cuda型, 否则会报错

相关文章:

  • UDF提权(mysql)
  • linux内核漏洞(CVE-2022-0847)
  • kubekey 离线部署 kubesphere v3.3.0
  • Git史上最详细教程(详细图解)
  • Python科学计算库练习题
  • 高性能MySQL实战第10讲:搭建稳固的MySQL运维体系
  • java毕业设计茶叶企业管理系统Mybatis+系统+数据库+调试部署
  • JAVA安装教程 (windows)
  • 6.hadoop文件数据库系列讲解
  • Day11OSI与TCP/IP协议簇以及物理层
  • Javaweb学生信息管理系统(Mysql+JSP+MVC+CSS)
  • ubuntu-hadoop伪分布
  • springboot 多环境配置(pom配置Profiles变量来,控制打包环境)
  • 计算机毕业设计ssm蓟县农家院网站2zl2w系统+程序+源码+lw+远程部署
  • 刷题记录(NC16645 [NOIP2007]矩阵取数游戏,NC207781 迁徙过程中的河流,NC235953 最大m个子段和)
  • 5、React组件事件详解
  • ABAP的include关键字,Java的import, C的include和C4C ABSL 的import比较
  • GitUp, 你不可错过的秀外慧中的git工具
  • Laravel 菜鸟晋级之路
  • Objective-C 中关联引用的概念
  • PhantomJS 安装
  • Vim 折腾记
  • vue从入门到进阶:计算属性computed与侦听器watch(三)
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 给Prometheus造假数据的方法
  • 三栏布局总结
  • 【干货分享】dos命令大全
  • C# - 为值类型重定义相等性
  • Hibernate主键生成策略及选择
  • shell使用lftp连接ftp和sftp,并可以指定私钥
  • # 安徽锐锋科技IDMS系统简介
  • #常见电池型号介绍 常见电池尺寸是多少【详解】
  • #基础#使用Jupyter进行Notebook的转换 .ipynb文件导出为.md文件
  • %check_box% in rails :coditions={:has_many , :through}
  • (免费领源码)python#django#mysql公交线路查询系统85021- 计算机毕业设计项目选题推荐
  • (牛客腾讯思维编程题)编码编码分组打印下标题目分析
  • (万字长文)Spring的核心知识尽揽其中
  • (一) storm的集群安装与配置
  • .NET Conf 2023 回顾 – 庆祝社区、创新和 .NET 8 的发布
  • .NET CORE 第一节 创建基本的 asp.net core
  • .NET Windows:删除文件夹后立即判断,有可能依然存在
  • .Net 应用中使用dot trace进行性能诊断
  • .Net 中Partitioner static与dynamic的性能对比
  • .Net 中的反射(动态创建类型实例) - Part.4(转自http://www.tracefact.net/CLR-and-Framework/Reflection-Part4.aspx)...
  • .NET简谈设计模式之(单件模式)
  • .Net开发笔记(二十)创建一个需要授权的第三方组件
  • .NET性能优化(文摘)
  • [ 第一章] JavaScript 简史
  • [ 蓝桥杯Web真题 ]-Markdown 文档解析
  • [android] 练习PopupWindow实现对话框
  • [AutoSar]BSW_Memory_Stack_003 NVM与APP的显式和隐式同步
  • [ccc3.0][数字钥匙] UWB配置和使用(二)
  • [EFI]Dell Inspiron 15 5567 电脑 Hackintosh 黑苹果efi引导文件
  • [Flutter]WindowsPlatform上运行遇到的问题总结
  • [IE技巧] IE8中HTTP连接数目的变化