(pt可视化)利用torch的make_grid进行张量可视化
在使用pytorch时,有时候需要对张量进行可视化,比如在经过一堆数据预处理后,我们从dataloader拿到了一个张量:[8,3,224,224],显然这是一个bs=8且为RGB的张量,一般来说经过ToTensor和Normalize后值范围在[-1,1],如果想看看这些张量是什么样子,一堆代码还是挺麻烦的,所以利用torch提供的make_grid和plt就能够轻松可视化张量。
make_graid():
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
- tensor:要可视化的张量,比如为[8,3,224,224]
- nrow:列数,行数=bs/列数
- padding:不同图像之间的间隙大小
- normalize:是否归一化,若是则按图像最大最小值归一化到[0,1]
- value_range:指定normalize使用的最大最小值,默认使用图像本身的最大最小值
- scale_each:是否单独为图像进行normalize。默认所有的图像都进行normalize
- pad_value:间隙的填充值。范围在0(间隙为黑色)~1(间隙为白色)之间
- return:返回(C,H,W)数据(多张图拼凑成了一张图)
plt:
上面我们得到了makr_graid生成的图像,我们使用plt来可视化:
npimg = vis.numpy() # plt输入需要时ndarray
plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest') # 需要将通道转到最后一维
plt.show()
关于plt显示图像,详见:7、显示图片
效果如下: