当前位置: 首页 > news >正文

损失函数与反向传播

一、损失函数

例子:

outputtarget
选择(10)选择(30)

填空(10)

填空(20)
解答(10)解答(50)
loss = (30-10)+(20-10)+(50-10)

loss的值越小越好,根据loss提高输出,神经网络根据loss的值不断的训练。计算实际输出与目标之间的差距,2、为更新输出提供一定的依据(反向传播)

二、损失函数代码

在代码运行中,出现如下问题got long:

RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long

需要加入的语句是加入dtype = torch.float32语句

input = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

 完整代码如下:

import torch
from torch.nn import L1Loss, MSELoss
# 防止自己写错的办法
from torch import nninput = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)# 将输入input转化为,形状为1 batch_size,1 chanel,1 行 3列
inputs = torch.reshape(input, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))# 损失函数是L1Loss
# 默认是做平均为0.667,也可以设置参数做和为2.0
loss = L1Loss(reduction="sum")
result = loss(inputs, targets)
print(result)# 损失函数是MSELoss,平方差
loss_MSE = nn.MSELoss()
result_mse = loss_MSE(inputs, targets)
print(result_mse)# 交叉熵 batch_size等于1,class等于3,表示3个类
x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
# 将x变成满足条件的(N, C)结构,输入为x
x = torch.reshape(x, [1, 3])
# 交叉熵
loss_cross = nn.CrossEntropyLoss()
result = loss_cross(x, y)
print(result)

三、神经网络预测

再例如,放入到之前的神经网络中处理,对图片进行预测:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('../datas', train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=1)class SUN(nn.Module):def __init__(self):super(SUN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 2, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x): x =self.model1(x)return xsun = SUN()
for data in dataloader:imgs, targets = dataoutputs = sun(imgs)print(outputs)print(targets)

输出的结果是:

0d9a704e599a436a9af1e2af686162a9.png

第一个tensor数据是图片的概率,其中概率最大的是第四张图像。预测出图像是第四张图片的类别。

四、与损失函数的误差

使用损失函数在最后一部分:

loss = nn.CrossEntropyLoss()sun = SUN()
for data in dataloader:imgs, targets = dataoutputs = sun(imgs)result = loss(outputs, targets)print(result)

输出结果,表示神经网络的输出与真实网络的输出误差:

789fda7511ec491c86a60071cc36bc18.png

五、反向传播

对于神经网络来说,每一个卷积和中的每一个参数就是我们送需要调节的,给每一个卷积核的参数都设置了一个grad。每一个节点,每一个参数都会提供一个grad。在优化过程中,就会根据grad来进行优化。实现对整个loss降低的目的。

将上述的代码进行debug查看grad:

d254b47c6d1a4724a27f8719ee5b3f8c.png

将断点打在backward处。

如何使用debug:

点击小虫子,运行结束后,出现各个变量;

选择自己搭建的神经网络;

选择Module;

选择私密属性;

选择modules;

点进变量去;

可以查看到grad=None

3101c664f7714534991f2c4c7d42eebb.png

运行下一句,按上述按钮,出现,grad的参数值。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 网络编程(TCP+网络模型)
  • Intel 在人工智能领域
  • ClickHouse 的安装与基本配置
  • C++深入理解哈希表的设计与实现:处理冲突的多种方法
  • Python股票接口实现量化交易的优势是什么
  • Ubuntu环境的MySql下载安装
  • Flutter自动打包ios ipa并且上传
  • 【BIO、NIO、AIO适用场景分析】
  • 大数据-119 - Flink Window总览 窗口机制-滚动时间窗口-基于时间驱动基于事件驱动
  • Word封面对齐技巧
  • 数据库中的逐行数据处理
  • FPGA随记——OSERDESE2和IERDESE2
  • (纯JS)图片裁剪
  • PyTorch 创建数据集
  • 《论系统安全架构设计及其应用》写作框架,软考高级系统架构设计师
  • 【翻译】Mashape是如何管理15000个API和微服务的(三)
  • Android开源项目规范总结
  • CAP理论的例子讲解
  • Create React App 使用
  • java架构面试锦集:开源框架+并发+数据结构+大企必备面试题
  • leetcode讲解--894. All Possible Full Binary Trees
  • Vue 动态创建 component
  • 海量大数据大屏分析展示一步到位:DataWorks数据服务+MaxCompute Lightning对接DataV最佳实践...
  • 免费小说阅读小程序
  • 深入浏览器事件循环的本质
  • 数据结构java版之冒泡排序及优化
  • 怎样选择前端框架
  • 主流的CSS水平和垂直居中技术大全
  • 带你开发类似Pokemon Go的AR游戏
  • 正则表达式-基础知识Review
  • ​Kaggle X光肺炎检测比赛第二名方案解析 | CVPR 2020 Workshop
  • ​Redis 实现计数器和限速器的
  • (bean配置类的注解开发)学习Spring的第十三天
  • (el-Date-Picker)操作(不使用 ts):Element-plus 中 DatePicker 组件的使用及输出想要日期格式需求的解决过程
  • (javaweb)Http协议
  • (ros//EnvironmentVariables)ros环境变量
  • (创新)基于VMD-CNN-BiLSTM的电力负荷预测—代码+数据
  • (附源码)ssm高校运动会管理系统 毕业设计 020419
  • (十七)Flink 容错机制
  • (循环依赖问题)学习spring的第九天
  • (转)EOS中账户、钱包和密钥的关系
  • (转)Sql Server 保留几位小数的两种做法
  • (转)总结使用Unity 3D优化游戏运行性能的经验
  • .bat批处理(八):各种形式的变量%0、%i、%%i、var、%var%、!var!的含义和区别
  • .equals()到底是什么意思?
  • .Net CF下精确的计时器
  • .Net CoreRabbitMQ消息存储可靠机制
  • .net MySql
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地定义和使用弱事件
  • .NET/C# 在 64 位进程中读取 32 位进程重定向后的注册表
  • .net中应用SQL缓存(实例使用)
  • @configuration注解_2w字长文给你讲透了配置类为什么要添加 @Configuration注解
  • @DependsOn:解析 Spring 中的依赖关系之艺术
  • @NoArgsConstructor和@AllArgsConstructor,@Builder
  • @test注解_Spring 自定义注解你了解过吗?