当前位置: 首页 > news >正文

优化器与现有网络模型的修改

文章目录

    • 一、优化器是什么
    • 二、优化器的使用
    • 三、分类模型VGG16
    • 四、现有网络模型的修改

一、优化器是什么

优化器(Optimizer)是一个算法,用于在训练过程中调整模型的参数,以便最小化损失函数(Loss Function)。损失函数衡量的是模型预测值与真实值之间的差异,而优化器则负责通过更新模型的权重(Weights)和偏置(Biases)来减少这种差异。

利用得到的梯度,用优化器对梯度进行修正,从而得到整体误差降低的目的。

优化器Optimizer 所需要从参数:

在这里插入图片描述

参数解析:

  • model.parameters()是训练的模型
  • lr(LearningRate)是学习率,这是最核心的参数之一,它决定了在每次迭代中参数更新的步长。如果学习率太高,可能会导致训练过程中的梯度爆炸,使模型无法收敛,训练很不稳定;如果学习率太低,训练过程可能会变得非常缓慢。
    推荐一开始用大的lr值进行运算,到后面用小的lr再进行运算。
  • 动量(Momentum)往往是特定参数,是用于加速梯度下降方法,特别是在处理凸优化问题时。它通过在连续的迭代中累积梯度信息来帮助优化器克服局部最小值,并加快收敛速度。

二、优化器的使用

本文使用我的上一章内容神经网络内容进行续写,神经网络具体可跳转损失函数和反向传播

使用一下代码来进行梯度优化:

    optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()

整体代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优optim.step()

在未运行时的梯度没有值:
在这里插入图片描述
当运行一下:
在这里插入图片描述
可以看到每个参数节点的值被计算出来了。

当for循环第二次运行的时候,可以看到grad梯度已经被优化了:

在这里插入图片描述

通过反复循环,上图中的data数据,也就是loss就会越来越被优化。

上面的for循环其实是为数据的一次小循环,我们可以加上epoch 外嵌套 进行数据的一轮轮循环深度优化:

for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_loss

整体代码:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download= True)dataloader = DataLoader (dataset, batch_size = 1)
class Sen(nn.Module):def __init__(self):super(Sen,self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
sen = Sen()#随机梯度下降
optim = torch.optim.SGD(sen.parameters(), lr=0.01)#这里是进行一轮一轮的学习
for epoch in range(20):running_loss = 0.0#这里只是进行了一次的学习for data in dataloader:imgs, tatgets = dataoutputs = sen(imgs)result_loss = loss(outputs, tatgets)#对参数进行梯度清零optim.zero_grad()# 向后传播result_loss.backward()#这一步对数值进行调优aoptim.step()#这一步就相当于所有误差的一个整体求和running_loss = running_loss + result_lossprint(running_loss)

运行结果如下,可以看到,整个神经网络在所有的数据当中,它的误差之和如下:

在这里插入图片描述

在第一轮优化的时候,整个神经网络的误差之和是18779
在第二轮优化的时候,整个神经网络的误差之和是16205
在第三轮优化的时候,整个神经网络的误差之和是15448

可以看到,通过优化器的一轮轮优化,整体的loss值会一直降低,从而达到数据优化的效果。

三、分类模型VGG16

pytorch为我们提供了很多网络模型,其中包括分类模型VGG16

分类模型VGG16是基于ImageNet数据集进行训练的,所以我们需要下载ImageNet数据集

由于ImageNet数据集的内存为143g,会发生以下报错,需要我们自己去下载ImageNet数据集再放在根目录当中。
在这里插入图片描述

既然ImageNet数据集太大,那么就换一条思路,用一下方法加载vgg16

import torchvision.datasets
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_True = torchvision.models.vgg16(pretrained=True)
print('ok')

如果pretrained = True,说明这个数据集已经是训练好的了。
如果pretrained = False,说明这些参数是一个初始参数,没有在任何参数集上面进行训练。
如果progress = True,显示下载进度条
如果progress = Flase,则不显示下载进度条

vgg16_false = torchvision.models.vgg16(pretrained=False),这代码表示只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的),所以它不需要下载。
vgg16_True = torchvision.models.vgg16(pretrained=True),这代码表示需要把网络模型参数进行一个下载,还要加载对应的参数。故它需要进行下载。
简单理解就是False不需要进行下载,而True需要进行下载。
VGG16将数据集分成1000个类。

print(vgg16_true)
输出结果:
在这里插入图片描述
在这里插入图片描述
看它把各种卷积层,最大池化都自动按参数下载好了。

常用的CIFAR10会把数据集分成10个类。
vgg16会把数据集分成1000个类,如上图的out_features=1000

四、现有网络模型的修改

方法:像上面得到的是out_features=1000,我们可以进行一个新的处理,通过Linear将输入是1000,而输出为10,从而达到降类的效果。

vgg16_true.add_module("add_linear", nn.Linear(1000, 10))

运行得到:
在这里插入图片描述
可以看到,在add_linear这里的out_features=10

如果要想类的改变在classifier当中,那么代码只需要添加上classifier

vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

运行结果:
在这里插入图片描述
整体代码如下:

import torchvision.datasets
from torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)train_data = torchvision.datasets.CIFAR10("./data",train=True, transform=torchvision.transforms.ToTensor(),download=True)vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))

如果想直接在上面 (6)Linear 里面修改out_features,而不是新命名一个(add_linear)进行修改也是可以的

用vgg16_flase进行示范:

在没进行修改前print(vgg16_false)

运行结果:
在这里插入图片描述
直接在(6)Linear中修改out_features为10

代码:

vgg16_false.classifier[6] = nn.Linear(4096, 10)

运行结果:
在这里插入图片描述
可以看到out_features=10,从而成功修改现有的网络模型。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 软件编程随想
  • 内存dump文件分析
  • STM32--基于PWM的呼吸灯实验
  • 服务器断电重启后报XFS文件系统错误 XFS (dm-0)_ Metadata I_O error
  • 多线程之CompletableFuture
  • nodejs 011: nodejs事件驱动编程 EventEmitter 与 IPC
  • SLA 概念和计算方法
  • 智慧课堂学生行为数据集
  • AI预测福彩3D采取888=3策略+和值012路或胆码测试9月19日新模型预测第92弹
  • 基于深度学习的零售柜商品识别系统实战思路
  • Vue2篇
  • 【60天备战2024年11月软考高级系统架构设计师——第21天:系统架构设计原则——高内聚低耦合】
  • C++实现的小游戏
  • watch和computed的使用及区别
  • Unity3D 小案例 像素贪吃蛇 02 蛇的觅食
  • C++类的相互关联
  • CSS盒模型深入
  • Fastjson的基本使用方法大全
  • js数组之filter
  • mac修复ab及siege安装
  • magento2项目上线注意事项
  • mysql外键的使用
  • nodejs:开发并发布一个nodejs包
  • Python语法速览与机器学习开发环境搭建
  • Swift 中的尾递归和蹦床
  • Three.js 再探 - 写一个跳一跳极简版游戏
  • Vue小说阅读器(仿追书神器)
  • 基于webpack 的 vue 多页架构
  • 基于游标的分页接口实现
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 想使用 MongoDB ,你应该了解这8个方面!
  • 一个JAVA程序员成长之路分享
  • 2017年360最后一道编程题
  • 如何在 Intellij IDEA 更高效地将应用部署到容器服务 Kubernetes ...
  • ​力扣解法汇总946-验证栈序列
  • ​马来语翻译中文去哪比较好?
  • $(document).ready(function(){}), $().ready(function(){})和$(function(){})三者区别
  • (2)MFC+openGL单文档框架glFrame
  • (3)选择元素——(17)练习(Exercises)
  • (C++17) std算法之执行策略 execution
  • (C++20) consteval立即函数
  • (DenseNet)Densely Connected Convolutional Networks--Gao Huang
  • (LNMP) How To Install Linux, nginx, MySQL, PHP
  • (考研湖科大教书匠计算机网络)第一章概述-第五节1:计算机网络体系结构之分层思想和举例
  • (实战篇)如何缓存数据
  • ***汇编语言 实验16 编写包含多个功能子程序的中断例程
  • .“空心村”成因分析及解决对策122344
  • .Net7 环境安装配置
  • .NET设计模式(11):组合模式(Composite Pattern)
  • .NET使用HttpClient以multipart/form-data形式post上传文件及其相关参数
  • .ui文件相关
  • .xml 下拉列表_RecyclerView嵌套recyclerview实现二级下拉列表,包含自定义IOS对话框...
  • /proc/stat文件详解(翻译)
  • @angular/cli项目构建--Dynamic.Form
  • [ 隧道技术 ] cpolar 工具详解之将内网端口映射到公网