FairScale
FairScale 是 PyTorch 扩展库,用于在一台或多台机器/节点上进行高性能和大规模训练
【示例】
在2个GPU上运行4层模型。前两层在cuda:0上运行,后两层在cuda:1上运行。
from torch import nn import fairscale model = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=6, kernel_size=(5,5), stride=1, padding=0), nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0), nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5,5), stride=1, padding=0), nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=0), ) model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)