当前位置: 首页 > news >正文

深度学习损失计算

文章目录

  • 深度学习损失计算
    • 1.如何计算当前epoch的损失?
    • 2.为什么要计算样本平均损失,而不是计算批次平均损失?

深度学习损失计算

1.如何计算当前epoch的损失?

深度学习中的损失计算,通常为数据集的平均损失,即每个样本的平均损失值。计算步骤如下:

  • 计算单个批次的损失。每次迭代中,用当前模型预测值和真实值计算损失。假设 _loss 是这次迭代中计算得到的损失。
  • 转换为标量。利用item()方法将其转换为标量值。_loss.item()
  • 乘以批次大小。乘以批次大小的原因是,希望总损失是所有数据点的损失总和,而不是批次平均损失。
  • 累加损失loss += _loss.item() * batch_size 将当前批次的总损失累加到变量 loss 中。这样所有批次遍历结束后,就得到一个epoch的总损失。
  • 计算当前epoch的样本平均损失。通过总损失除以总的数据样本数,来得到平均损失。average_loss = loss/len(dataloader.dataset)【注意:除的是总的数据样本数(len(dataloader.dataset))!不是总的批次数(len(dataloader))!】

示例代码如下:

for epoch in total_epoch:  # epoch迭代total_loss = 0.0  # 初始化总损失for inputs, targets in dataloader:  # batch迭代outputs = model(inputs)  # 获取预测值_loss = criterion(outputs, targets)  # 计算当前批次损失,为批次平均损失batch_size = inputs.size(0)  # 获取批次大小total_loss += _loss.item() * batch_size  # 计算当前批次的总损失# 计算当前epoch的平均损失average_loss = total_loss / len(dataloader.dataset)  

2.为什么要计算样本平均损失,而不是计算批次平均损失?

由于每个批次的大小可能不一样,特别是在数据集的大小不是批次大小的整数倍时,所以使用 len(dataloader) 会导致错误的平均损失计算。

下面用一个简单的例子,解释这两种计算方式的不同:

假设数据集有 105 个样本,每个批次大小为 10,这样会有 11 个批次,其中最后一个批次只有 5 个样本。结合上面的伪代码,假设损失值 _loss.item() 是 1,对于 10 个批次的损失是 10,最后一个批次的损失是 5。那么:

  • t o t a l _ l o s s = ( 1 ∗ 10 ) ∗ 10 + ( 1 ∗ 5 ) ∗ 1 = 105 total\_loss = (1 * 10) * 10 + (1 * 5) * 1 = 105 total_loss=(110)10+(15)1=105
  • l e n ( d a t a l o a d e r . d a t a s e t ) = 105 len(dataloader.dataset) = 105 len(dataloader.dataset)=105
  • l e n ( d a t a l o a d e r ) = 11 len(dataloader) = 11 len(dataloader)=11

计算结果:

  • 样本平均损失计算:average_loss = total_loss / len(dataloader.dataset) 105 / 105 = 1 105/105 = 1 105/105=1
  • 批次平均损失计算:average_loss = total_loss / len(dataloader) 105 / 11 ≈ 9.545 105/11 \approx 9.545 105/119.545

显然,第一种方式是正确的,反映了每个样本的真实平均损失。

😃😃😃

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • SpringBoot使用开发环境的application.properties
  • go语言 fmt的几个打印区别以及打印格式
  • Linux内核启用 bridge 模块
  • UPFC统一潮流控制器的simulink建模与仿真
  • React、Vue的password输入框组件,如何关闭自动填充?
  • Go 语言中的互斥锁 Mutex
  • ARFoundation系列讲解 - 91 Immersal 简介
  • 独角数卡(自动发卡系统)开源自动化售货最新2.0.6
  • 【UE5.1】NPC人工智能——02 NPC移动到指定位置
  • 高性能存储 SIG 月度动态:优化 xfs dax reflink 时延,独立选型并维护 mdadm 和 ledmon
  • gradle学习及问题
  • 【Unity学习笔记】第十九 · 物理引擎约束求解解惑(LCP,最优,拉格朗日乘数法,SI,PGS,基于冲量法)
  • docker-cli nerdctl ctr crictl容器命令比较
  • 基于jeecgboot-vue3的Flowable流程支持bpmn流程设计器与仿钉钉流程设计器-编辑多版本处理
  • NLP入门——RNN、LSTM模型的搭建、训练与预测
  • [ 一起学React系列 -- 8 ] React中的文件上传
  • [译]前端离线指南(上)
  • 【技术性】Search知识
  • css系列之关于字体的事
  • Java Agent 学习笔记
  • java 多线程基础, 我觉得还是有必要看看的
  • Js基础知识(一) - 变量
  • WePY 在小程序性能调优上做出的探究
  • 分布式熔断降级平台aegis
  • 前言-如何学习区块链
  • 什么软件可以剪辑音乐?
  • 微信如何实现自动跳转到用其他浏览器打开指定页面下载APP
  • 【云吞铺子】性能抖动剖析(二)
  • 回归生活:清理微信公众号
  • ​html.parser --- 简单的 HTML 和 XHTML 解析器​
  • ​HTTP与HTTPS:网络通信的安全卫士
  • #pragma once
  • #数据结构 笔记一
  • (160)时序收敛--->(10)时序收敛十
  • (创新)基于VMD-CNN-BiLSTM的电力负荷预测—代码+数据
  • (待修改)PyG安装步骤
  • (二)Kafka离线安装 - Zookeeper下载及安装
  • (附源码)springboot 智能停车场系统 毕业设计065415
  • (附源码)ssm学生管理系统 毕业设计 141543
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (四)activit5.23.0修复跟踪高亮显示BUG
  • (一一四)第九章编程练习
  • .net core 6 redis操作类
  • .NET Core 中插件式开发实现
  • .NET 中让 Task 支持带超时的异步等待
  • .NET高级面试指南专题十一【 设计模式介绍,为什么要用设计模式】
  • // an array of int
  • @AliasFor注解
  • [ Socket学习 ] 第一章:网络基础知识
  • [20170713] 无法访问SQL Server
  • [3D游戏开发实践] Cocos Cyberpunk 源码解读-高中低端机性能适配策略
  • [AIGC] 广度优先搜索(Breadth-First Search,BFS)详解
  • [AutoSAR 存储] 汽车智能座舱的存储需求
  • [C#]实现GRPC通讯的服务端和客户端实例
  • [docker] Docker的数据卷、数据卷容器,容器互联