当前位置: 首页 > news >正文

PyTorch 基础学习(13)- 混合精度训练

系列文章:
《PyTorch 基础学习》文章索引

基本概念

混合精度训练是深度学习中一种优化技术,旨在通过结合高精度(torch.float32)和低精度(如 torch.float16torch.bfloat16)数据类型的优势,提高计算效率和内存利用率。

  • 高精度(torch.float32:适合需要大动态范围的操作,如损失计算、缩减操作(如求和、平均)等。这些操作对数值稳定性要求较高,使用高精度能确保计算结果的准确性。

  • 低精度(torch.float16torch.bfloat16:适合计算密集型操作,如卷积和矩阵乘法。这些操作在低精度下可以显著提升计算速度,同时减少显存占用。

混合精度训练的核心思想是在模型中自动选择合适的数据类型,以在加速计算的同时,尽可能保持结果的准确性。PyTorch 提供了 torch.amp 模块,该模块封装了一些便捷的工具,使得混合精度的实现更加直观和高效。

重要方法及其作用

torch.autocast

torch.autocast 是混合精度训练中的核心工具。它是一个上下文管理器或装饰器,用于在代码的特定部分启用混合精度。在这些被启用的区域内,autocast 将根据操作的特性自动选择合适的数据类型。例如,卷积操作可以自动转换为 float16,而损失计算则保持为 float32

主要参数:

  • device_type:指定设备类型,如 cudacpuxpu
  • dtype:指定在 autocast 区域内使用的低精度数据类型。对于 CUDA 设备,默认是 torch.float16;对于 CPU 设备,默认是 torch.bfloat16
  • enabled:是否启用混合精度。默认为 True
  • cache_enabled:是否启用权重缓存。默认是 True,可以在某些场景下提高性能。

torch.amp.GradScaler

在低精度(如 float16)下,梯度值较小的操作可能会出现下溢现象,导致梯度值变为零,从而影响模型的训练。为了避免这种情况,PyTorch 提供了 GradScaler,它通过在反向传播之前动态缩放损失值,从而放大梯度值,使其在低精度下也能被有效表示。之后,优化器会在更新参数之前对梯度进行反缩放,以确保不会影响学习率。

主要参数:

  • init_scale:初始的缩放因子,默认是 65536.0
  • growth_factor:在没有发生下溢的情况下,缩放因子增长的倍数,默认是 2.0
  • backoff_factor:发生下溢时,缩放因子减少的倍数,默认是 0.5
  • growth_interval:在多少个步骤之后,如果没有下溢,缩放因子会增长,默认是 2000
  • enabled:是否启用梯度缩放,默认为 True

适用的场景

GPU 训练
在使用 CUDA 设备进行深度学习模型训练时,启用混合精度可以显著提升模型的训练速度。尤其是在使用大规模数据和复杂模型(如卷积神经网络、Transformer 模型)时,torch.autocast(device_type="cuda") 能够有效地减少 GPU 的计算负载,并提高吞吐量。

CPU 训练与推理
虽然 GPU 在深度学习中更常用,但在一些特定场景下(如低资源环境或需要在 CPU 上进行部署),混合精度在 CPU 上同样具有优势。使用 torch.autocast(device_type="cpu", dtype=torch.bfloat16) 可以在推理过程中降低计算复杂度,同时保持较高的精度。

3.3 自定义操作
在某些高级用例中,用户可能需要为自定义的自动微分函数实现混合精度支持。通过 torch.amp.custom_fwdtorch.amp.custom_bwd,用户可以定义在特定设备(如 cuda)上执行的前向和反向操作,并确保这些操作在混合精度模式下正常运行。

应用实例

以下是一个在 CUDA 设备上使用混合精度进行训练的完整示例,展示了如何在实践中应用 torch.autocasttorch.amp.GradScaler

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler# 定义简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型和优化器,使用默认精度(float32)
model = SimpleModel().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 创建GradScaler
scaler = GradScaler()# 训练循环
for epoch in range(10):  # 假设有10个epochfor input, target in data_loader:  # 假设有一个data_loaderinput, target = input.cuda(), target.cuda()optimizer.zero_grad()# 在前向传播过程中启用自动混合精度with autocast(device_type="cuda"):output = model(input)loss = loss_fn(output, target)# 使用GradScaler进行反向传播scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()print(f"Epoch {epoch+1} completed.")

代码说明

  • 首先,我们定义了一个简单的神经网络模型,并将其放置在 CUDA 设备上。
  • 在每次训练循环中,我们使用 torch.autocast(device_type="cuda") 上下文管理器包裹前向传播过程,使得模型的计算自动使用混合精度。
  • 使用 GradScaler 对损失进行缩放,并在缩放后的损失上调用 backward() 进行反向传播。这一步骤有助于防止梯度下溢。
  • scaler.step(optimizer) 用于更新模型参数,scaler.update() 则是调整缩放因子。

这种方法既能提高训练速度,又能在较低精度下保持数值稳定性,是在实际项目中应用混合精度训练的有效方案。

注意事项

  • 弃用警告:从 PyTorch 1.10 开始,原有的 torch.cuda.amp.autocasttorch.cpu.amp.autocast 方法被弃用,推荐使用通用的 torch.autocast 代替。这不仅简化了接口,也为未来的设备扩展提供了灵活性。

  • 数据类型匹配:在使用 autocast 时,确保输入数据类型的一致性非常重要。如果在混合精度区域内生成的张量在退出后与其他不同精度的张量混合使用,可能会导致类型不匹配错误。因此,在必要时,需要手动将张量转换为 float32 或其他合适的精度。

  • GradScaler 的适用性:虽然 GradScaler 对大多数模型都有效,但在某些情况下(例如使用 bf16 预训练模型),可能会出现梯度溢出的情况。因此,在使用混合精度训练时,需要根据具体模型的特性进行调整。

通过对这些概念、方法、使用场景和实例的深入理解,您可以在实际项目中更好地应用混合精度训练,从而提升深度学习模型的训练效率和性能。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • C++ 设计模式——建造者模式
  • Redis—持久化机制
  • StarRocks 存算分离数据回收原理
  • jpg怎么转换成pdf?6个简单方法,实现jpg转换成pdf
  • 设计模式(一):单例模式
  • 数字IC/FPGA中有符号数的处理探究
  • Python|OpenCV-基于OpenCV进行图像的复制与克隆(19)
  • 第五章 设置和其他常见活动 - 创建 IRIS 凭证集
  • 【hot100篇-python刷题记录】【买卖股票的最佳时机】
  • django之自定义序列化器用法
  • 【Java学习】反射和枚举详解
  • 微服务网关
  • 基于python的自适应svm电影评价倾向性分析设计与实现
  • 全光谱日照模拟系统汽车整车光老化测试 太阳光照射模拟器
  • 【10.2 python中的类的定义和使用】
  • Angular6错误 Service: No provider for Renderer2
  • co.js - 让异步代码同步化
  • Docker: 容器互访的三种方式
  • Git同步原始仓库到Fork仓库中
  • Java的Interrupt与线程中断
  • Linux CTF 逆向入门
  • MySQL Access denied for user 'root'@'localhost' 解决方法
  • nginx(二):进阶配置介绍--rewrite用法,压缩,https虚拟主机等
  • spring-boot List转Page
  • 持续集成与持续部署宝典Part 2:创建持续集成流水线
  • 初探 Vue 生命周期和钩子函数
  • 基于web的全景—— Pannellum小试
  • 京东美团研发面经
  • 前端面试之闭包
  • 如何学习JavaEE,项目又该如何做?
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • 最简单的无缝轮播
  • Hibernate主键生成策略及选择
  • 阿里云IoT边缘计算助力企业零改造实现远程运维 ...
  • ​ssh-keyscan命令--Linux命令应用大词典729个命令解读
  • ‌JavaScript 数据类型转换
  • # 利刃出鞘_Tomcat 核心原理解析(七)
  • #《AI中文版》V3 第 1 章 概述
  • #android不同版本废弃api,新api。
  • $ is not function   和JQUERY 命名 冲突的解说 Jquer问题 (
  • $.ajax()参数及用法
  • (01)ORB-SLAM2源码无死角解析-(56) 闭环线程→计算Sim3:理论推导(1)求解s,t
  • (办公)springboot配置aop处理请求.
  • (二)Linux——Linux常用指令
  • (二十六)Java 数据结构
  • (一)硬件制作--从零开始自制linux掌上电脑(F1C200S) <嵌入式项目>
  • (译)2019年前端性能优化清单 — 下篇
  • (转)Linq学习笔记
  • .gitignore文件—git忽略文件
  • .net core 微服务_.NET Core 3.0中用 Code-First 方式创建 gRPC 服务与客户端
  • .net core使用EPPlus设置Excel的页眉和页脚
  • .net 设置默认首页
  • .NET简谈设计模式之(单件模式)
  • .NET使用HttpClient以multipart/form-data形式post上传文件及其相关参数
  • /boot 内存空间不够