当前位置: 首页 > news >正文

PyTorch Lightning入门教程(二)

文章目录

  • PyTorch Lightning入门教程(二)
    • 前言
    • 单机多卡
    • 多机多卡
    • 半精度训练

PyTorch Lightning入门教程(二)

前言

pytorch lightning提供了比较方便的多GPU训练方式,同时包括多种策略和拓展库,比如ddp,fairscale等,下面将从单机多卡和多机多卡两个角度介绍。

单机多卡

pytorch lightning的官网提供了比较详细的使用方法,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu.html

一般来说,只要在trainer的参数中指定了参数gpus,就可以使用多GPU运行了,例如:

Trainer(gpus=4)  # 使用4块显卡进行计算
Trainer(gpus=[0, 2]) # 使用0和2号显卡进行计算

当然这里支持多种写法来加载GPU,这里有说明https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_basic.html,可以参考。

需要注意的是这有一个strategy参数,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html,pytorch lightning支持多种训练框架,包括了dp,dpp,horovod,bagua等。所以这里别忘记设定好所使用的多GPU框架。

多机多卡

这部分和单机多卡的区别不大,只是增加了一个参数num_nodes,来设定所使用的机器数量,例如:

Trainer(gpus=4, num_nodes=3) # 表示使用3台机器,共12张显卡 

之后需要注意的是启动进程,可以参考https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html

针对上面那个例子,如果想要启动训练,需要在三台机器上分别输入下面命令:

master_ip=192.168.1.10
机器1:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python main.py
机器2:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python main.py
机器3:
CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python main.py

这样就可以在3台机器上进行训练了

如果训练的时候报错或者卡死,无法运行,可以试试在上面的命令中加上NCCL_IB_DISABLE=1,将NCCL_IB_DISABLE设置为1来禁止使用IB/RoCE传输方式,转而使用IP传输,对于不支持RDMA技术的服务器,这个值设置为1可以解决部分训练卡死的问题。如果网络接口不能被自动发现,则手工设置NCCL_SOCKET_IFNAME=eth0,如果还有问题,就设置NCCL的debug模式NCCL_DEBUG=INFO

注:需要注意的是,由于目前pytorch lightning还在开发发展中,很多新的功能只有新版本的才有,所以需要注意自己的pytorch lightning的版本,比如本文中提到的strategy策略,笔者在使用的时候,发现对于1.5.10的pytorch lightning版本,使用fairscale策略,无法在多机多卡的情况下使用,后来升级到1.7.4版本之后,才可以正常使用。

半精度训练

这里设置比较简单,只需要新加一个参数precision即可:

Trainer(gpus=4, num_nodes=3, precision=16)

但是需要注意的是并不是所有的情况下都支持半精度训练,比如DataParallel就不支持半精度,而DistributedDataParallel是可以支持的,所以平时多机多卡训练就不要用DataParallel。

相关文章:

  • 【滤波跟踪】基于变分贝叶斯卡尔曼滤波器实现目标跟踪附matlab代码
  • C++ mutex 与 condition_variable
  • 基础 | Spring - [单例创建过程]
  • K8S集群Pod资源自动扩缩容方案
  • SPPNet
  • java多线程-多线程技能
  • 网课查题接口 该怎么搭建
  • Elasticsearch学习-- 聚合查询
  • 网课搜题公众号接口
  • ubuntu18.04.1LTS 编译安装ffmpeg详解
  • 接口幂等问题:redis分布式锁解决方案
  • 算法与数据结构(第一周)——线性查找法
  • 修改docker 修改容器配置
  • ARM汇编语言
  • 【通信】非正交多址接入(NOMA)和正交频分多址接入(OFDMA)的性能对比附matlab代码
  • 【391天】每日项目总结系列128(2018.03.03)
  • 【腾讯Bugly干货分享】从0到1打造直播 App
  • 2018一半小结一波
  • 3.7、@ResponseBody 和 @RestController
  • Angular 响应式表单 基础例子
  • CSS选择器——伪元素选择器之处理父元素高度及外边距溢出
  • docker容器内的网络抓包
  • gulp 教程
  • JavaScript 无符号位移运算符 三个大于号 的使用方法
  • JavaScript工作原理(五):深入了解WebSockets,HTTP/2和SSE,以及如何选择
  • java多线程
  • Java反射-动态类加载和重新加载
  • JS进阶 - JS 、JS-Web-API与DOM、BOM
  • laravel 用artisan创建自己的模板
  • markdown编辑器简评
  • MYSQL如何对数据进行自动化升级--以如果某数据表存在并且某字段不存在时则执行更新操作为例...
  • Nodejs和JavaWeb协助开发
  • PHP那些事儿
  • puppeteer stop redirect 的正确姿势及 net::ERR_FAILED 的解决
  • vue2.0开发聊天程序(四) 完整体验一次Vue开发(下)
  • 阿里云Kubernetes容器服务上体验Knative
  • 百度地图API标注+时间轴组件
  • 从零开始在ubuntu上搭建node开发环境
  • 回顾 Swift 多平台移植进度 #2
  • 京东美团研发面经
  • 实现菜单下拉伸展折叠效果demo
  • 提升用户体验的利器——使用Vue-Occupy实现占位效果
  • 网络应用优化——时延与带宽
  • 异步
  • 用Python写一份独特的元宵节祝福
  • 直播平台建设千万不要忘记流媒体服务器的存在 ...
  • # 学号 2017-2018-20172309 《程序设计与数据结构》实验三报告
  • (1/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (9)STL算法之逆转旋转
  • (AngularJS)Angular 控制器之间通信初探
  • (Matlab)遗传算法优化的BP神经网络实现回归预测
  • (第8天)保姆级 PL/SQL Developer 安装与配置
  • (附源码)springboot 基于HTML5的个人网页的网站设计与实现 毕业设计 031623
  • (九十四)函数和二维数组
  • (免费领源码)python#django#mysql校园校园宿舍管理系统84831-计算机毕业设计项目选题推荐