当前位置: 首页 > news >正文

模型训练套路(一)

一、训练完整使用网络模型

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from model1 import* # 此处的引用为此文在实现过程中所解决的问题

train_data = torchvision.datasets.CIFAR10(root = "../data", train=True,
                                          transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root = "../data", train=False,
                                          transform=torchvision.transforms.ToTensor(),download=True)
# 查看数据集的长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 格式化 # 格式化注意的是,之间是.的连接
print("训练数据集的长度为: {}".format(train_data_size))
print("测试数据集的长度为: {}".format(test_data_size))

# 利用dataloade r加载数据集#加载数据集的参数设置
train_dataloader = DataLoader (train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
sun = SUN()

# 损失函数 交叉熵函数的使用
loss_fn = nn.CrossEntropyLoss()

# 优化器(SGD随机梯度下降)
learning_rate = 0.01
optimizer = torch.optim.SGD(sun.parameters(), lr = learning_rate)

# 设置网络训练的参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

for i in range(epoch):
    print("-------第{}轮训练开始--------".format(i+1))
    # 训练网络模型,从训练的data中取数据
    # 训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = sun(imgs)
        # 将得到的输出与真实的target比较,得到误差
        loss =loss_fn(outputs, targets)

        # 优化器优化模型
        # 进行优化,首先是梯度清零
        optimizer.zero_grad()
        # 得到每个节点的梯度
       loss.backward()
        # 对其中的参数进行优化
       optimizer.step()

        total_test_step = total_test_step + 1
        print("训练次数:{}, Loss: {}".format(total_test_step, loss.item()))

二、调用的神经网络模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequentialclass SUN(nn.Module):def __init__(self):super(SUN, self).__init__() self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2), nn.MaxPool2d(2),  nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2), nn.Conv2d(32, 64, 5, 1, 2), nn.MaxPool2d(2),nn.Flatten(),   nn.Linear(1024, 64), nn.Linear(64, 10))def forward(self, x):x =self.model(x)return xif __name__ == '__main__':sun = SUN()input = torch.ones((64, 3, 32,32))output = sun(input)print(output.shape)

三、调用python文件

在调用的python文件时,会出现一些问题:

from model1 import*

使用该语句调用,但是model1会画红色波浪线报错,并且,引用的神经网络也会出现报错,原因就是,未正确引用py文件。

尝试的解决办法:使用.model1,这种办法不可取后;

使用标记目录仍未成功;

最终,神经网络的py文件与训练的该文件在同一目录下,将被引用的Py文件,放在需引用文件的上一级目录下。也就是说,被引用文件在需引用文件的上一级。

# 接套路一代码:

# 如何知道数据训练好了没有
# 利用现有模型进行测试
# 在测试数据集上走一遍,以测试数据集的损失,来判定模型训练好了没有
# 测试过程中不需要在对模型进行调优
# 测试步骤开始
    total_test_loss = 0 with torch.no_grad(): # 将参数梯度调零for data in test_dataloader:imgs, targets = dataoutputs = sun(imgs)loss =loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()print("整体测试集上的Loss:{}".format(total_test_loss))writer.add_scalar("test_loss",total_test_loss, total_train_step)total_test_step +=1

                # 对模型的保存

        torch.save(sun, "sun_{}.path".format(i))print("模型已保存")writer.close()

在tensorboard上显示:

经过10轮的训练,测试集与训练集的损失值变化。

输出(outputs)与最终的预测(predicts)之间的转变,使用函数Argmax,就能够求出横向的最大值所在的位置。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • [数据集][目标检测]街道乱堆垃圾检测数据集VOC+YOLO格式94张1类别
  • 【数学建模】2024数学建模国赛B题(word论文+matlab):生产过程中的决策问题
  • C++ STL-List容器概念及应用方法详解
  • 2024年高教社杯数学建模国赛C题超详细解题思路分析
  • Linux 一个简单的中断信号实现
  • 力扣100题——子串
  • 经验笔记:SSL证书
  • Stream插件相关的用法
  • 操作系统概述及特征
  • 回溯——7.子集II
  • 【蓝桥杯嵌入式(一)程序框架和调度器】
  • 《机器学习》 基于SVD的矩阵分解 推导、案例实现
  • AI基础 L1 Introduction to Artificial Intelligence
  • k8s技术架构
  • 多维时序 | Matlab基于SSA-SVR麻雀算法优化支持向量机的数据多变量时间序列预测
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • Java 11 发布计划来了,已确定 3个 新特性!!
  • learning koa2.x
  • MYSQL 的 IF 函数
  • Python 反序列化安全问题(二)
  • Spring Boot MyBatis配置多种数据库
  • 从零到一:用Phaser.js写意地开发小游戏(Chapter 3 - 加载游戏资源)
  • 动手做个聊天室,前端工程师百无聊赖的人生
  • 分布式事物理论与实践
  • 前端每日实战:70# 视频演示如何用纯 CSS 创作一只徘徊的果冻怪兽
  • 手写双向链表LinkedList的几个常用功能
  • 提升用户体验的利器——使用Vue-Occupy实现占位效果
  • 微信小程序:实现悬浮返回和分享按钮
  • 我的业余项目总结
  • 详解NodeJs流之一
  • 一起来学SpringBoot | 第十篇:使用Spring Cache集成Redis
  • 硬币翻转问题,区间操作
  • 关于Android全面屏虚拟导航栏的适配总结
  • ​ssh免密码登录设置及问题总结
  • ​马来语翻译中文去哪比较好?
  • # 详解 JS 中的事件循环、宏/微任务、Primise对象、定时器函数,以及其在工作中的应用和注意事项
  • (~_~)
  • (react踩过的坑)antd 如何同时获取一个select 的value和 label值
  • (附源码)springboot炼糖厂地磅全自动控制系统 毕业设计 341357
  • (接口封装)
  • (三)mysql_MYSQL(三)
  • (四)搭建容器云管理平台笔记—安装ETCD(不使用证书)
  • (自用)gtest单元测试
  • ./mysql.server: 没有那个文件或目录_Linux下安装MySQL出现“ls: /var/lib/mysql/*.pid: 没有那个文件或目录”...
  • .net core 3.0 linux,.NET Core 3.0 的新增功能
  • .net 重复调用webservice_Java RMI 远程调用详解,优劣势说明
  • .NET/C# 使用 ConditionalWeakTable 附加字段(CLR 版本的附加属性,也可用用来当作弱引用字典 WeakDictionary)
  • .Net+SQL Server企业应用性能优化笔记4——精确查找瓶颈
  • .Net6 Api Swagger配置
  • .NET6 开发一个检查某些状态持续多长时间的类
  • .net利用SQLBulkCopy进行数据库之间的大批量数据传递
  • ??Nginx实现会话保持_Nginx会话保持与Redis的结合_Nginx实现四层负载均衡
  • [ 手记 ] 关于tomcat开机启动设置问题
  • [1525]字符统计2 (哈希)SDUT
  • [2023年]-hadoop面试真题(一)