当前位置: 首页 > news >正文

Pytorch dataloader中的num_workers (选择最合适的num_workers值)

      num_workers是Dataloader的概念,默认值是0. 是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关)

     如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度慢。

    当num_worker不为0时,每轮到dataloader加载数据时,dataloader一次性创建num_worker个worker,并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。

  num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮...迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些

num_worker小了的情况,主进程采集完最后一个worker的batch。此时需要回去采集第一个worker产生的第二个batch。如果该worker此时没有采集完,主线程会卡在这里等。(这种情况出现在,num_works数量少或者batchsize 比较小,显卡很快就计算完了,CPU对GPU供不应求。)

即,num_workers的值和模型训练快慢有关,和训练出的模型的performance无关

Detectron2的num_workers默认是4

选择最合适的num_workers值

最合适的num_works值与数据集有关

最好是跑代码之前先用这段script跑一下,选择最合适的num_workers值

from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms


transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)

print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):  
    train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
    start = time()
    for epoch in range(1, 3):
        for i, data in enumerate(train_loader, 0):
            pass
    end = time()
    print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

可以看到,这个服务器24个CPU, 最合适的num_workers值是14

相关文章:

  • rsync数据传输
  • Pytorch torch.utils.data.DataLoader(二) —— pin_memory锁页内存 drop_last num_works
  • Pytorch分布式训练/多卡训练(三) —— Model Parallel 并行
  • Pytorch backend 通信后端
  • PyTorch多卡/多GPU/分布式DPP的基本概念(noderanklocal_ranknnodesnode_ranknproc_per_nodeworld_size)
  • Pytorch函数keepdim=True
  • Python opencv putText()中文乱码问题
  • Python类的__call__方法
  • Python处理XML(ElementTree)
  • Pytorch为不同层设置不同的学习率(全局微调)
  • PIL Image和opencv读入图片相互转化
  • Python 星号表达式*(starred expression / unpack / 解包)
  • Pytorch/Python计算交并比IOU(IU)代码(批量算IOU)
  • Pytorch apply() 函数
  • Python namedtuple数据结构(命名元组)(collections)
  • 自己简单写的 事件订阅机制
  • android百种动画侧滑库、步骤视图、TextView效果、社交、搜房、K线图等源码
  • - C#编程大幅提高OUTLOOK的邮件搜索能力!
  • CSS实用技巧干货
  • es6--symbol
  • in typeof instanceof ===这些运算符有什么作用
  • Map集合、散列表、红黑树介绍
  • maven工程打包jar以及java jar命令的classpath使用
  • Netty 框架总结「ChannelHandler 及 EventLoop」
  • Nodejs和JavaWeb协助开发
  • PHP的类修饰符与访问修饰符
  • Python 基础起步 (十) 什么叫函数?
  • python学习笔记 - ThreadLocal
  • Redash本地开发环境搭建
  • vue自定义指令实现v-tap插件
  • 蓝海存储开关机注意事项总结
  • 力扣(LeetCode)21
  • 力扣(LeetCode)965
  • 猫头鹰的深夜翻译:JDK9 NotNullOrElse方法
  • 融云开发漫谈:你是否了解Go语言并发编程的第一要义?
  • 写给高年级小学生看的《Bash 指南》
  • 鱼骨图 - 如何绘制?
  • 运行时添加log4j2的appender
  • 在Unity中实现一个简单的消息管理器
  • Semaphore
  • 好程序员web前端教程分享CSS不同元素margin的计算 ...
  • ​DB-Engines 12月数据库排名: PostgreSQL有望获得「2020年度数据库」荣誉?
  • ​草莓熊python turtle绘图代码(玫瑰花版)附源代码
  • #1015 : KMP算法
  • #HarmonyOS:Web组件的使用
  • (1) caustics\
  • (Arcgis)Python编程批量将HDF5文件转换为TIFF格式并应用地理转换和投影信息
  • (BFS)hdoj2377-Bus Pass
  • (论文阅读23/100)Hierarchical Convolutional Features for Visual Tracking
  • (七)微服务分布式云架构spring cloud - common-service 项目构建过程
  • (算法)前K大的和
  • (中等) HDU 4370 0 or 1,建模+Dijkstra。
  • ****Linux下Mysql的安装和配置
  • .net Application的目录
  • .Net core 6.0 升8.0