一文搞懂 | Pytorch维度转换操作:view,reshape,permute,flatten函数详解
《博主简介》
小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
👍感谢小伙伴们点赞、关注!
《------往期经典推荐------》
一、AI应用软件开发实战专栏【链接】
项目名称 | 项目名称 |
---|---|
1.【人脸识别与管理系统开发】 | 2.【车牌识别与自动收费管理系统开发】 |
3.【手势识别系统开发】 | 4.【人脸面部活体检测系统开发】 |
5.【图片风格快速迁移软件开发】 | 6.【人脸表表情识别系统】 |
7.【YOLOv8多目标识别与自动标注软件开发】 | 8.【基于YOLOv8深度学习的行人跌倒检测系统】 |
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】 | 10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】 |
11.【基于YOLOv8深度学习的安全帽目标检测系统】 | 12.【基于YOLOv8深度学习的120种犬类检测与识别系统】 |
13.【基于YOLOv8深度学习的路面坑洞检测系统】 | 14.【基于YOLOv8深度学习的火焰烟雾检测系统】 |
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】 | 16.【基于YOLOv8深度学习的舰船目标分类检测系统】 |
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】 | 18.【基于YOLOv8深度学习的血细胞检测与计数系统】 |
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】 | 20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】 |
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】 | 22.【基于YOLOv8深度学习的路面标志线检测与识别系统】 |
23.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】 | 24.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】 |
25.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】 | 26.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】 |
27.【基于YOLOv8深度学习的人脸面部表情识别系统】 | 28.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】 |
29.【基于YOLOv8深度学习的智能肺炎诊断系统】 | 30.【基于YOLOv8深度学习的葡萄簇目标检测系统】 |
31.【基于YOLOv8深度学习的100种中草药智能识别系统】 | 32.【基于YOLOv8深度学习的102种花卉智能识别系统】 |
33.【基于YOLOv8深度学习的100种蝴蝶智能识别系统】 | 34.【基于YOLOv8深度学习的水稻叶片病害智能诊断系统】 |
35.【基于YOLOv8与ByteTrack的车辆行人多目标检测与追踪系统】 | 36.【基于YOLOv8深度学习的智能草莓病害检测与分割系统】 |
37.【基于YOLOv8深度学习的复杂场景下船舶目标检测系统】 | 38.【基于YOLOv8深度学习的农作物幼苗与杂草检测系统】 |
39.【基于YOLOv8深度学习的智能道路裂缝检测与分析系统】 | 40.【基于YOLOv8深度学习的葡萄病害智能诊断与防治系统】 |
41.【基于YOLOv8深度学习的遥感地理空间物体检测系统】 | 42.【基于YOLOv8深度学习的无人机视角地面物体检测系统】 |
43.【基于YOLOv8深度学习的木薯病害智能诊断与防治系统】 | 44.【基于YOLOv8深度学习的野外火焰烟雾检测系统】 |
45.【基于YOLOv8深度学习的脑肿瘤智能检测系统】 | 46.【基于YOLOv8深度学习的玉米叶片病害智能诊断与防治系统】 |
47.【基于YOLOv8深度学习的橙子病害智能诊断与防治系统】 | 48.【车辆检测追踪与流量计数系统】 |
49.【行人检测追踪与双向流量计数系统】 | 50.【基于YOLOv8深度学习的反光衣检测与预警系统】 |
51.【危险区域人员闯入检测与报警系统】 | 52.【高密度人脸智能检测与统计系统】 |
53.【CT扫描图像肾结石智能检测系统】 | 54.【水果智能检测系统】 |
55.【水果质量好坏智能检测系统】 | 56.【蔬菜目标检测与识别系统】 |
57.【非机动车驾驶员头盔检测系统】 | 58.【太阳能电池板检测与分析系统】 |
59.【工业螺栓螺母检测】 | 60.【金属焊缝缺陷检测系统】 |
61.【链条缺陷检测与识别系统】 | 62.【交通信号灯检测识别】 |
二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】,持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~
《------正文------》
目录
- 引言
- view函数
- 代码示例
- permute函数
- 代码示例
- Reshape函数
- flatten函数
- 代码示例
引言
在深度学习网络构建与计算过程中,我们经常会使用到张量维度之间的各种转换,用于不同操作。Pytorch中常见的维度转换函数有view,reshape,permute,flatten
。本文将详细介绍这几个函数的作用与使用方式,希望能够帮助大家。
常见的维度有四维:比如(batch, channel, height, width);三维:比如(b,n,c);二维:比如(h,w)。下面介绍如何使用上述函数进行维度之间的转换。
view函数
作用
tensor.view() 可以用来调整张量的形状
,这对于在网络层之间传递数据或者在处理图像数据时非常有用。需要注意的是,新的形状必须与原始张量的元素数量一致。
参数
size (tuple of ints) – 新的大小应该与原张量元素数量相匹配
。可以指定一个尺寸为 -1 的维度来自动计算合适的大小。
代码示例
将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)
转为三维(Batch,N,Channel)
形式。
import torch
# view使用示例
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])
B, C, H, W = x.size()# 转为BNC
x = x.view(B, -1, C)
# 或者 x = x.view(B, H*W, C)
print(x.shape) #torch.Size([16, 4096, 3])
torch.randn() 是 PyTorch 中的一个函数,用于生成一个填充了从标准正态分布(均值为 0,方差为 1)中随机抽取的数字的张量。
permute函数
作用
permute() 函数用于改变张量的维度顺序。它接受一个新的维度顺序作为参数,并返回一个新的张量,其维度顺序按照给定的顺序排列。
参数说明
参数:一个元组,表示新的维度顺序。
例如,对于一个形状为 (10, 3, 32, 32) 的张量,permute(0, 2, 3, 1) 表示新的维度顺序为 (10, 32, 32, 3)。其中0,1,2,3分别表示4个维度(10, 3, 32, 32)的索引。
代码示例
依然将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。
import torch
# permute使用示例:permute转换唯独顺序
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])# 16,3,64,64的维度索引分别为0,1,2,3
dim_change = x.permute(0,2,3,1) # 转为 B,H,W,C
# 然后将中间两个通道索引为[1,2]展平
out = dim_change.flatten(start_dim=1,end_dim=2)
print(out.shape) #torch.Size([16, 4096, 3])
flatten() 方法用于展平张量的一个或多个维度。它可以接受两个可选参数:
start_dim:从哪个维度开始展平,默认为 0。
end_dim:到哪个维度结束展平,默认为 -1,表示直到最后一个维度。
此处的作用是将第二个和第三个维度进行展平。
start_dim=1 表示从第二个维度(即 64)开始展平。
end_dim=2 表示到第三个维度(即 64)结束展平。
展平后的结果为 (16, 4096, 3),其中 4096= 64 * 64。
通过这些步骤,你可以将原始张量从 (16,3,64,64) 转换为 (16, 4096, 3)。
Reshape函数
torch.reshape() 可以改变张量的形状,而不改变张量中的数据
。与view函数的作用类似。
注意事项:新旧形状的元素总数必须相同。
import torch# 创建一个简单的张量
x = torch.randn(4, 3)
print("Original tensor:")
print(x)# 使用 torch.reshape() 来改变张量的形状
# 将 (4, 3) 的张量转换成 (2, 6) 的张量
reshaped_x = torch.reshape(x, (2, 6))
print("\nReshaped tensor:")
print(reshaped_x)# 如果不确定某个维度的大小,可以使用 -1 让 PyTorch 自动计算
# 这里将 (4, 3) 转换为 (12,) 的一维张量
flat_x = torch.reshape(x, (-1))
print("\nFlattened tensor:")
print(flat_x)# 更复杂的形状变换
# 将 (4, 3) 转换为 (3, 4) 的张量
complex_reshaped_x = torch.reshape(x, (3, 4))
print("\nComplex reshaped tensor:")
print(complex_reshaped_x)
flatten函数
torch.flatten 是 PyTorch 库中的一个函数,用于将一个多维张量转换为一维张量或降低其维度。
torch.flatten参数说明
input: 这是要被展平的张量。这是必需的参数。
start_dim (可选): 指定从哪个维度开始展平。默认值为 0,这意味着展平将从第一个维度(通常是批量大小)开始。如果你希望保留前几个维度并只展平后续的维度,你可以设置这个参数。
end_dim (可选): 指定展平到哪个维度结束。默认值为 -1,这表示展平将一直持续到最后一个维度。如果只想展平中间的一部分维度,可以设置这个参数来指定结束维度。
当 start_dim 和 end_dim 都没有被显式地指定时,torch.flatten 将会展平除了第一个维度之外的所有维度,通常第一个维度是批量大小,会被保留以便于批次处理。
代码示例
举个例子,假设你有一个形状为 [batch_size, channels, height, width] 的四维张量,如果你想将其展平为 [batch_size, channels * height * width] 的二维张量,你可以直接调用 torch.flatten 而不需要额外的参数。但是,如果你想保留通道维度,并展平高度和宽度维度,你可以设置 start_dim=1 和 end_dim=2。
import torch# 创建一个形状为 [8, 3, 64, 64] 的随机张量
x = torch.randn(8, 3, 64, 64)# 展平除了第一个维度外的所有维度
y = torch.flatten(x)
print(y.shape) # 输出: torch.Size([8, 12288])# 只展平第二和第三个维度[也就是最后两个维度],0,1,2,3
z = torch.flatten(x, 1, 2)
print(z.shape) # 输出: torch.Size([8, 3, 4096])
关注文末名片G-Z-H:【阿旭算法与机器学习】,发送【开源】可获取更多学习资源
好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!