当前位置: 首页 > news >正文

基于PyTorch分布式训练的实现与优化(以CIFAR-10为例)

基于PyTorch DDP分布式训练的实现与优化

引言

在深度学习的实际应用中,训练大型神经网络常常需要大量的计算资源和时间。为了提高训练效率,利用多GPU进行分布式训练成为了一种常见的解决方案。本文将介绍如何使用PyTorch框架进行分布式训练,特别是通过DistributedDataParallel(DDP)模块实现更高效的并行化处理。

环境设置

在开始编写分布式训练代码之前,首先需要设置多GPU的运行环境。示例代码中首先通过setup函数配置每个进程的运行环境,包括指定主节点地址和端口,并初始化进程组:

def setup(rank, world_size):os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12357'torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)torch.cuda.set_device(rank)
数据处理

为了保证数据在不同GPU间均衡,使用DistributedSampler确保每个进程处理数据的唯一性和均匀性。同时,利用PyTorch的DataLoader来并行加载数据,提高数据处理效率。

train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank)
train_loader = DataLoader(train_set, batch_size=16, sampler=train_sampler)
模型构建与训练

本示例使用ResNet-18作为训练模型,并通过DistributedDataParallel封装以支持多GPU训练。DDP能够自动分配模型的参数到各个GPU,并同步参数的更新,极大地提高了训练的速度和效率。

model = resnet18(pretrained=False, num_classes=10).cuda(rank)
model = DDP(model, device_ids=[rank])

训练过程中,通过循环遍历训练数据,执行前向传播、损失计算、反向传播和参数更新。利用tqdm库显示训练进度,增加交互性。

for epoch in range(10):model.train()pbar = tqdm(train_loader, desc="Training")for data in pbar:inputs, labels = data[0].cuda(rank), data[1].cuda(rank)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()pbar.set_postfix(Loss=loss.item(),Epoch=epoch,Rank=rank)
性能评估

在每个训练周期后,通过评估函数计算模型在测试集上的准确率,从而监控模型的学习效果。

if rank == 0:accuracy = evaluate(model, rank, test_loader)print(f"Rank {rank}, Test Accuracy: {accuracy}%")
总结

本文展示了如何在PyTorch框架下使用DistributedDataParallel模块进行分布式训练。通过此技术,可以有效地利用多个GPU资源,加速大规模数据集上的模型训练。希望本文能帮助读者更好地理解和应用PyTorch的分布式训练技术。

参考文献
  1. PyTorch官方文档:DistributedDataParallel
  2. CIFAR-10数据集介绍

通过本文的介绍,读者可以在自己的项目中实现类似的多GPU训练任务,从而提升训练效率和模型性能。

附录完整代码:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Subset
from torchvision.models import resnet18
import wandb
import numpy as np
from tqdm import tqdm
def setup(rank, world_size):os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12357'torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)torch.cuda.set_device(rank)def cleanup():torch.distributed.destroy_process_group()
def evaluate(model, device, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn accuracy
def main(rank, world_size):setup(rank, world_size)# if rank==0:#     wandb.init(#         project="project-DDP-Cifar10",##         # track hyperparameters and run metadata#         config={#             "learning_rate": 1e-3,#         }#     )# 数据变换transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据集train_set = torchvision.datasets.CIFAR10(root='/mnt/sda/zjb/data/cifar10', train=True, download=True, transform=transform)Downsampled_Sample = False #是否下采样数据if Downsampled_Sample:# 计算要抽取的样本数量(四分之一)# 创建所有索引的列表indices = list(range(len(train_set)))split = int(np.floor(0.25 * len(train_set)))# 取四分之一的随机索引subset_indices = indices[:split]subset=Subset(train_set,subset_indices)train_sampler = DistributedSampler(subset, num_replicas=world_size, rank=rank)else:train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank)train_loader = DataLoader(train_set, batch_size=16, sampler=train_sampler)test_set = torchvision.datasets.CIFAR10(root='/mnt/sda/zjb/data/cifar10', train=False, download=True, transform=transform)test_sampler = DistributedSampler(test_set, num_replicas=world_size, rank=rank)test_loader = DataLoader(test_set, batch_size=64, sampler=test_sampler)# 模型定义model = resnet18(pretrained=False, num_classes=10).cuda(rank)model = DDP(model, device_ids=[rank])# 损失函数和优化器criterion = nn.CrossEntropyLoss().cuda(rank)optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程for epoch in range(10):model.train()pbar = tqdm(train_loader, desc="Training")for data in pbar:inputs, labels = data[0].cuda(rank), data[1].cuda(rank)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()pbar.set_postfix(Loss=loss.item(),Epoch=epoch,Rank=rank)print(f"Rank {rank}, Epoch {epoch}, Loss: {loss.item()}")# 评估模型if rank == 0 :accuracy = evaluate(model, rank, test_loader)print(f"Rank {rank}, Test Accuracy: {accuracy}%")cleanup()if __name__ == "__main__":world_size = 4 #GPU数量torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 如何搭建RGBD GS-ICP SLAM环境以及如何与自己编的pcl并存
  • 如何在JSON对象中查询特定的值?C语言实现
  • Linux 命令行/bash脚本 批量创建文件
  • Python基础语法(17多线程线程锁单例模式)
  • Android13默认开启电池百分比数字显示Framework
  • 山东大学机试试题合集
  • 服务器数据恢复—OneFS文件系统下数据被删除的数据恢复案例
  • UE驻网失败问题(三)
  • C++知识点总结
  • Pr 入门系列之二:导入与管理素材(上)
  • OSI七层网络协议
  • 【论文阅读】一种针对多核神经网络处理器的窃取攻击(2020)
  • 7:python第三章:更多的数据类型2(字典)
  • SAPUI5基础知识25 - 聚合绑定(Aggregation Binding)
  • CentOS7单机环境安装k8s集群
  • Akka系列(七):Actor持久化之Akka persistence
  • dva中组件的懒加载
  • MYSQL 的 IF 函数
  • Promise初体验
  • React-flux杂记
  • SegmentFault 技术周刊 Vol.27 - Git 学习宝典:程序员走江湖必备
  • vue-router的history模式发布配置
  • webpack+react项目初体验——记录我的webpack环境配置
  • 编写符合Python风格的对象
  • 关于Android中设置闹钟的相对比较完善的解决方案
  • 开源中国专访:Chameleon原理首发,其它跨多端统一框架都是假的?
  • 用jquery写贪吃蛇
  • postgresql行列转换函数
  • ​力扣解法汇总1802. 有界数组中指定下标处的最大值
  • ​如何使用QGIS制作三维建筑
  • ​用户画像从0到100的构建思路
  • # linux 中使用 visudo 命令,怎么保存退出?
  • #window11设置系统变量#
  • (1)虚拟机的安装与使用,linux系统安装
  • (9)STL算法之逆转旋转
  • (Redis使用系列) Springboot 使用redis的List数据结构实现简单的排队功能场景 九
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (纯JS)图片裁剪
  • (第61天)多租户架构(CDB/PDB)
  • (每日一问)操作系统:常见的 Linux 指令详解
  • (四)linux文件内容查看
  • (四)opengl函数加载和错误处理
  • (四)软件性能测试
  • (未解决)jmeter报错之“请在微信客户端打开链接”
  • (五)IO流之ByteArrayInput/OutputStream
  • (学习日记)2024.02.29:UCOSIII第二节
  • (转)Unity3DUnity3D在android下调试
  • (转)重识new
  • .cn根服务器被攻击之后
  • .NET Core、DNX、DNU、DNVM、MVC6学习资料
  • .NET Core使用NPOI导出复杂,美观的Excel详解
  • .NET 将混合了多个不同平台(Windows Mac Linux)的文件 目录的路径格式化成同一个平台下的路径
  • .net/c# memcached 获取所有缓存键(keys)
  • .NET/C# 使用反射注册事件
  • .NET连接MongoDB数据库实例教程