当前位置: 首页 > news >正文

pytorch DistributedDataParallel 分布式训练踩坑记录

目录

    • 一、几个比较常见的概念:
    • 二、踩坑记录
      • 2.1 dist.init_process_group初始化
      • 2.2 spawn启动(rank怎么来的)
      • 2.3 loss backward
      • 2.4 model cuda设置
      • 2.5 数据加载

一、几个比较常见的概念:

  • rank: 多机多卡时代表某一台机器,单机多卡时代表某一块GPU
  • world_size: 多机多卡时代表有几台机器,单机多卡时代表有几块GPU
    world_size = torch.cuda.device_count()
    
  • local_rank: 多机多卡时代表某一块GPU, 单机多卡时代表某一块GPU
    单机多卡的情况要比多机多卡的情况常见的多。
  • DP:适用于单机多卡(=多进程)训练。算是旧版本的DDP
  • DDP:适用于单机多卡训练、多机多卡。

二、踩坑记录

2.1 dist.init_process_group初始化

这一步就是设定一个组,这个组里面设定你有几个进程(world_size),现在是卡几(rank)。让pycharm知道你要跑几个进程,包装在组内,进行通讯这样模型参数会自己同步,不需要额外操作了。

import os
import torch.distributed as distdef ddp_setup(rank,world_size):os.environ['MASTER_ADDR'] = 'localhost' #rank0 对应的地址os.environ['MASTER_PORT'] = '6666' #任何空闲的端口dist.init_process_group(backend='nccl',  #nccl Gloo #nvidia显卡的选择ncclworld_size=world_size, init_method='env://',rank=rank) #初始化默认的分布进程组dist.barrier() #等到每块GPU运行到这再继续往下走

2.2 spawn启动(rank怎么来的)

rank是自动分配的。怎么分配呢?这里用的是spawn也就一行代码。

import torch.multiprocessing as mp
def main (rank:int,world_size:int,args):pass#训练代码 主函数mp.spawn(main,args=(args.world_size,args), nprocs=args.world_size)

注意,调用spawn的时候,没有输入main的其中一个参数rank,rank由代码自动分配。将代码复制两份在两张卡上同时跑,你可以print(rank),会发现输出 0 1。两份代码并行跑。

另外,nprocs=args.world_size。如果不这么写,代码会卡死在这,既不报错,也不停止。

2.3 loss backward

one of the variables needed for gradient computation has been modified by an inplace operationRuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

经过调试发现,当使用nn.DataParallel并行训练或者单卡训练均可正常运行;另外如果将两次模型调用集成到model中,即通过out1, out2 = model(input0, input1) 的方式在分布式训练下也不会报错。

在分布式训练中,如果对同一模型进行多次调用则会触发以上报错,即nn.parallel.DistributedDataParallel方法封装的模型,forword()函数和backward()函数必须交替执行,如果执行多个(次)forward()然后执行一次backward()则会报错。

解决此问题可以聚焦到nn.parallel.DistributedDataParallel接口上,通过查询PyTorch官方文档发现此接口下的两个参数:

  • find_unused_parameters: 如果模型的输出有不需要进行反向传播的,此参数需要设置为True;若你的代码运行后卡住不动,基本上就是该参数的问题。
  • broadcast_buffers: 该参数默认为True,设置为True时,在模型执行forward之前,gpu0会把buffer中的参数值全部覆盖到别的gpu上。
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False, find_unused_parameters=True)

2.4 model cuda设置

RuntimeError: NCCL error in: ../torch/lib/c10d/ProcessGroupNCCL.cpp:859, invalid usage, NCCL version 21.1.1
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).

*这是因为model和local_rank所指定device不一致引起的错误。

model.cuda(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],broadcast_buffers=False,find_unused_parameters=True)

2.5 数据加载

使用distributed加载数据集,需要使用DistributedSampler自动为每个gpu分配数据,但需要注意的是sampler和shuffle=True不能并存。

train_sampler = DistributedSampler(trainset)
train_loader = torch.utils.data.DataLoader(trainset,batch_size=args.train_batch_size,num_workers=args.train_workers,sampler=train_sampler)

相关文章:

  • 【问题记录】docker pull 镜像的时候 devel 版本和无 devel 版本的差别
  • 使用 eBPF检测 mmap泄露
  • 【电路笔记】-节点电压分析和网状电流分析
  • EDA实验----四选一多路选择器设计(QuartusII)
  • Java中单例模式
  • Echarts柱状体实现滚动条动态滚动
  • Spring源码系列-框架中的设计模式
  • [工业自动化-11]:西门子S7-15xxx编程 - PLC从站 - 分布式IO从站/从机
  • 【C++笔记】优先级队列priority_queue的模拟实现
  • 原型模式(创建型)
  • 解析html生成Word文档
  • 总结:利用原生JDK封装工具类,解析properties配置文件以及MF清单文件
  • 七个优秀微服务跟踪工具
  • 微服务-开篇-个人对微服务的理解
  • 【Springboot】基于注解式开发Springboot-Vue3整合Mybatis-plus实现分页查询
  • 【剑指offer】让抽象问题具体化
  • 【跃迁之路】【669天】程序员高效学习方法论探索系列(实验阶段426-2018.12.13)...
  • 4个实用的微服务测试策略
  • CentOS从零开始部署Nodejs项目
  • CSS 提示工具(Tooltip)
  • Date型的使用
  • Just for fun——迅速写完快速排序
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • React Transition Group -- Transition 组件
  • Sass 快速入门教程
  • swift基础之_对象 实例方法 对象方法。
  • 多线程 start 和 run 方法到底有什么区别?
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 如何解决微信端直接跳WAP端
  • 思维导图—你不知道的JavaScript中卷
  • 阿里云ACE认证之理解CDN技术
  • 数据可视化之下发图实践
  • ​DB-Engines 11月数据库排名:PostgreSQL坐稳同期涨幅榜冠军宝座
  • ​Linux Ubuntu环境下使用docker构建spark运行环境(超级详细)
  • #Z2294. 打印树的直径
  • (C语言)字符分类函数
  • (Matalb分类预测)GA-BP遗传算法优化BP神经网络的多维分类预测
  • (pytorch进阶之路)CLIP模型 实现图像多模态检索任务
  • (八十八)VFL语言初步 - 实现布局
  • (附源码)apringboot计算机专业大学生就业指南 毕业设计061355
  • (附源码)计算机毕业设计SSM在线影视购票系统
  • (接口封装)
  • (一)Thymeleaf用法——Thymeleaf简介
  • (译) 函数式 JS #1:简介
  • (转)chrome浏览器收藏夹(书签)的导出与导入
  • (转)mysql使用Navicat 导出和导入数据库
  • (转)shell中括号的特殊用法 linux if多条件判断
  • .libPaths()设置包加载目录
  • .NET 4.0网络开发入门之旅-- 我在“网” 中央(下)
  • .NET Compact Framework 多线程环境下的UI异步刷新
  • .Net Core 中间件验签
  • .NET国产化改造探索(三)、银河麒麟安装.NET 8环境
  • .NET企业级应用架构设计系列之结尾篇
  • @RequestBody详解:用于获取请求体中的Json格式参数
  • [2009][note]构成理想导体超材料的有源THz欺骗表面等离子激元开关——