当前位置: 首页 > news >正文

神经网络---网络模型的保存、加载

方式1:结构+参数

保存

import torch
import torchvision
from torch import nn
from torchvision.models import vgg16, VGG16_Weightsvgg16 = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)# 保存方式1, 模型解构+模型参数
torch.save(vgg16, 'vgg16_1.pth')

加载

from p26_model_svae import *# 方式1 -》 保存方式1,加载模型
model = torch.load('vgg16.pth')
print(model)

方式1的陷阱
自定义网络结构如下:

import torch
import torchvision
from torch import nnclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)return x
torch.save(Tudui(), 'tudui_1.pth')

在另一个文件加载该模型,会报错
正确的调用格式需要复制原模型的类定义

class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)return x
model = torch.load('tudui_1.pth')
print(model)

或者用import

from p26_model_svae import *model = torch.load('tudui_1.pth')
print(model)

方式2 模型参数(官方推荐)

import torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式2 模型参数(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_2.pth')

模型加载(在另一个文件加载)

# 方式2 ,加载模型
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load('vgg16_2.pth'))
# model2 = torch.load('vgg16_2.pth')  #字典形式
print(vgg16)

相关文章:

  • 分治算法例子
  • OceanBase v4.2 解读:tenant=all 语义优化,提升易用性
  • Java Web学习笔记4——HTML、CSS
  • PyTorch 的 torch.nn 模块学习
  • 正则表达式----IP地址合法性判断
  • 啵啵啵啵啵啵啵啵啵啵啵啵啵啵啵
  • Java面试——中间件
  • 嵌入式Linux系统编程 — 2.1 标准I/O库简介
  • cs与msf权限传递
  • 最大矩形问题
  • 如何给 MySQL 表和列授予权限?(官方版)
  • HBuilderX编写APP一、获取token
  • Polar Web【简单】upload1
  • 【Meetup】探索Apache SeaTunnel的二次开发与实战案例
  • 数据结构初阶 堆(一)
  • 《Java编程思想》读书笔记-对象导论
  • ➹使用webpack配置多页面应用(MPA)
  • Android 架构优化~MVP 架构改造
  • Angular6错误 Service: No provider for Renderer2
  • C学习-枚举(九)
  • Java比较器对数组,集合排序
  • java第三方包学习之lombok
  • Markdown 语法简单说明
  • node入门
  • springMvc学习笔记(2)
  • webgl (原生)基础入门指南【一】
  • 程序员最讨厌的9句话,你可有补充?
  • 分享几个不错的工具
  • 使用 Xcode 的 Target 区分开发和生产环境
  • 视频flv转mp4最快的几种方法(就是不用格式工厂)
  • 数据可视化之 Sankey 桑基图的实现
  • 微信开放平台全网发布【失败】的几点排查方法
  • 我建了一个叫Hello World的项目
  • 移动端唤起键盘时取消position:fixed定位
  • 在electron中实现跨域请求,无需更改服务器端设置
  • 正则表达式
  • Hibernate主键生成策略及选择
  • 阿里云ACE认证学习知识点梳理
  • 完善智慧办公建设,小熊U租获京东数千万元A+轮融资 ...
  • 小白应该如何快速入门阿里云服务器,新手使用ECS的方法 ...
  • #QT(一种朴素的计算器实现方法)
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • (16)UiBot:智能化软件机器人(以头歌抓取课程数据为例)
  • (C语言)输入一个序列,判断是否为奇偶交叉数
  • (Mirage系列之二)VMware Horizon Mirage的经典用户用例及真实案例分析
  • (草履虫都可以看懂的)PyQt子窗口向主窗口传递参数,主窗口接收子窗口信号、参数。
  • (九)c52学习之旅-定时器
  • (三分钟)速览传统边缘检测算子
  • .NET/C# 异常处理:写一个空的 try 块代码,而把重要代码写到 finally 中(Constrained Execution Regions)
  • ?
  • @31省区市高考时间表来了,祝考试成功
  • @Transient注解
  • [ 隧道技术 ] cpolar 工具详解之将内网端口映射到公网
  • []利用定点式具实现:文件读取,完成不同进制之间的
  • [20180129]bash显示path环境变量.txt