当前位置: 首页 > news >正文

8-pytorch-损失函数与反向传播

b站小土堆pytorch教程学习笔记

根据loss更新模型参数
1.计算实际输出与目标之间的差距
2.为我们更新输出提供一定的依据(反向传播)

在这里插入图片描述

1 MSEloss

import torch
from torch.nn import L1Loss
from torch import nninputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))loss=L1Loss()
result=loss(inputs,targets)loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

在这里插入图片描述

x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=1)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:imgs,target=dataoutput=han(imgs)# print(target)# print(output)result_loss=loss(output,target)print(result_loss)

*tensor([7])
tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)…

tensor(2.2664, grad_fn=)…

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=64)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.model1=Sequential(Conv2d(3,32,5,padding=2),MaxPool2d(2),Conv2d(32,32,5,padding=2),MaxPool2d(2),Conv2d(32,64,5,padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,x):x=self.model1(x)return xloss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)for epoch in range(5):running_loss=0.0#一个epoch结束的loss和for data in dataloader:imgs,target=dataoutput=han(imgs)result_loss=loss(output,target)#每次迭代的lossoptim.zero_grad()#将网络中每个可调节参数对应的梯度调为0result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得optim.step()#对每个参数调优running_loss=running_loss+result_lossprint(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

相关文章:

  • Qt读写局域网共享文件夹
  • nodejs:nvm(nodejs版本管理切换工具)
  • [SpringDataMongodb开发游戏服务器实战]
  • Camunda7.18流程引擎启动出现Table ‘camunda_platform_docker.ACT_GE_PROPERTY‘的解决方案
  • week04day02(爬虫02)
  • OSCP靶场--Slort
  • 【python基础学习2】python里和可迭代对象iterator相关的函数:zip(), map(), join() 函数和strip()方法等
  • Guitar Pro8.2吉他软件2024中文版功能特点介绍
  • 【课程作业】提取图中苹果的面积、周长和最小外接矩形的python、matlab和c++代码
  • 【Mongo】mongodump/mongoexport/mongoimport 操作
  • Python | OS模块操作
  • 设计模式学习笔记 - 面向对象 - 7.为什么要多用组合少用继承?如何决定该用组合还是继承?
  • Linux的时间操作
  • Java:获取PDF文件的总页数
  • 第2.6章 StarRocks表设计——数据压缩
  • const let
  • IDEA 插件开发入门教程
  • Intervention/image 图片处理扩展包的安装和使用
  • java取消线程实例
  • Selenium实战教程系列(二)---元素定位
  • sessionStorage和localStorage
  • Stream流与Lambda表达式(三) 静态工厂类Collectors
  • vue中实现单选
  • 创建一种深思熟虑的文化
  • 基于Vue2全家桶的移动端AppDEMO实现
  • 计算机常识 - 收藏集 - 掘金
  • 浅谈web中前端模板引擎的使用
  • 如何实现 font-size 的响应式
  • 入职第二天:使用koa搭建node server是种怎样的体验
  • 使用Tinker来调试Laravel应用程序的数据以及使用Tinker一些总结
  • 事件委托的小应用
  • 数据仓库的几种建模方法
  • 吴恩达Deep Learning课程练习题参考答案——R语言版
  • 一个6年java程序员的工作感悟,写给还在迷茫的你
  • kubernetes资源对象--ingress
  • puppet连载22:define用法
  • 阿里云重庆大学大数据训练营落地分享
  • ​决定德拉瓦州地区版图的关键历史事件
  • (附源码)小程序儿童艺术培训机构教育管理小程序 毕业设计 201740
  • (论文阅读26/100)Weakly-supervised learning with convolutional neural networks
  • (亲测)设​置​m​y​e​c​l​i​p​s​e​打​开​默​认​工​作​空​间...
  • (全注解开发)学习Spring-MVC的第三天
  • (一)基于IDEA的JAVA基础10
  • (原創) 博客園正式支援VHDL語法著色功能 (SOC) (VHDL)
  • (转)shell调试方法
  • (转)VC++中ondraw在什么时候调用的
  • (转)如何上传第三方jar包至Maven私服让maven项目可以使用第三方jar包
  • (转载)OpenStack Hacker养成指南
  • .NET BackgroundWorker
  • .net core webapi 大文件上传到wwwroot文件夹
  • .Net 中的反射(动态创建类型实例) - Part.4(转自http://www.tracefact.net/CLR-and-Framework/Reflection-Part4.aspx)...
  • .NET6使用MiniExcel根据数据源横向导出头部标题及数据
  • .net开发时的诡异问题,button的onclick事件无效
  • /etc/shadow字段详解
  • @SentinelResource详解