当前位置: 首页 > news >正文

【深度学习】(7)--神经网络之保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset) # 总数据大小num_batches = len(dataloader) # 划分的小批次数量model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)print(model.state_dict().keys()) # 输出模型参数名称cnntorch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 2. 保存完整模型(w,b,模型cnn)torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

相关文章:

  • 每日一题|1845. 座位预约管理系统|最小堆操作、优先队列
  • Linux系统中命令wc
  • 用css实现改变图片滤镜
  • Ubuntu20.04安装ros2
  • 2024必读NLP书籍!《自然语言处理:基于预训练模型的方法》附PDF!
  • 网站服务器在不同操作系统上监听端口情况的方法
  • 卫华集团再次惠购宏山激光30kW大幅面激光切割机,全力构建新质生产力
  • scrapy 爬取微博(五)【最新超详细解析】: 爬取微博文章
  • oracle direct path read处理过程
  • uniapp js怎么根据map需要显示的点位,计算自适应的缩放scale
  • 【Unity踩坑】Textmesh Pro是否需要加入Version Control?
  • 经典sql题(十四)炸裂函数的恢复
  • 资金晋阶司库|基于数字化标准建立的操作类应用
  • 生物医学光学第三章作业:归纳和总结生物发光的主要类型和特点
  • Linux 网络配置 (深入理解)
  • 【跃迁之路】【463天】刻意练习系列222(2018.05.14)
  • classpath对获取配置文件的影响
  • CSS居中完全指南——构建CSS居中决策树
  • extjs4学习之配置
  • iOS编译提示和导航提示
  • IP路由与转发
  • Js实现点击查看全文(类似今日头条、知乎日报效果)
  • Linux编程学习笔记 | Linux IO学习[1] - 文件IO
  • Three.js 再探 - 写一个跳一跳极简版游戏
  • vue中实现单选
  • 给自己的博客网站加上酷炫的初音未来音乐游戏?
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 用jQuery怎么做到前后端分离
  • SAP CRM里Lead通过工作流自动创建Opportunity的原理讲解 ...
  • 阿里云IoT边缘计算助力企业零改造实现远程运维 ...
  • ​浅谈 Linux 中的 core dump 分析方法
  • # SpringBoot 如何让指定的Bean先加载
  • #我与Java虚拟机的故事#连载11: JVM学习之路
  • (17)Hive ——MR任务的map与reduce个数由什么决定?
  • (6)添加vue-cookie
  • (7)STL算法之交换赋值
  • (函数)颠倒字符串顺序(C语言)
  • (七)Appdesigner-初步入门及常用组件的使用方法说明
  • (未解决)jmeter报错之“请在微信客户端打开链接”
  • (学习日记)2024.04.04:UCOSIII第三十二节:计数信号量实验
  • (译) 理解 Elixir 中的宏 Macro, 第四部分:深入化
  • .htaccess配置重写url引擎
  • .NET Core 成都线下面基会拉开序幕
  • .Net CoreRabbitMQ消息存储可靠机制
  • .NET Framework 的 bug?try-catch-when 中如果 when 语句抛出异常,程序将彻底崩溃
  • .NET 直连SAP HANA数据库
  • .NET/C# 获取一个正在运行的进程的命令行参数
  • .net操作Excel出错解决
  • .NET分布式缓存Memcached从入门到实战
  • .NET企业级应用架构设计系列之结尾篇
  • /etc/motd and /etc/issue
  • @data注解_SpringBoot 使用WebSocket打造在线聊天室(基于注解)
  • @requestBody写与不写的情况
  • @RequestBody与@ModelAttribute
  • [2015][note]基于薄向列液晶层的可调谐THz fishnet超材料快速开关——