当前位置: 首页 > news >正文

理解张量拼接(torch.cat)

拼接

维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。 选择一个dim进行拼接的时候其他两个维度大小要相等
![[Pasted image 20240808214248.png]]

对于三维张量,理解 torch.catdim 参数确实变得更加抽象,但原理是相同的。让我们通过一个具体的例子来说明这一点。

import torch# 创建两个 3D 张量
a = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
b = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])print("Tensor a shape:", a.shape)
print(a)
print("\nTensor b shape:", b.shape)
print(b)# dim=0 连接
c_dim0 = torch.cat([a, b], dim=0)
print("\nResult of torch.cat([a, b], dim=0):")
print("Shape:", c_dim0.shape)
print(c_dim0)# dim=1 连接
c_dim1 = torch.cat([a, b], dim=1)
print("\nResult of torch.cat([a, b], dim=1):")
print("Shape:", c_dim1.shape)
print(c_dim1)# dim=2 连接
c_dim2 = torch.cat([a, b], dim=2)
print("\nResult of torch.cat([a, b], dim=2):")
print("Shape:", c_dim2.shape)
print(c_dim2)

现在让我们详细解释这个三维张量的例子:

  1. 初始张量:

    • ab 都是形状为 (2, 2, 2) 的 3D 张量。
    • 可以将它们想象成两个 2x2 的矩阵堆叠在一起。
  2. dim=0 连接:

    • 结果形状:(4, 2, 2)
    • 这相当于在第一个维度上堆叠张量。
    • 可以理解为将 b 放在 a 的"下面",增加了第一个维度的大小。
  3. dim=1 连接:

    • 结果形状:(2, 4, 2)
    • 这相当于在第二个维度上堆叠张量。
    • 可以理解为在每个 2x2 矩阵的"行"方向上扩展,将 b 的行添加到 a 的每个对应部分的下方。
  4. dim=2 连接:

    • 结果形状:(2, 2, 4)
    • 这相当于在第三个维度(最内层)上堆叠张量。
    • 可以理解为在每个 2x2 矩阵的"列"方向上扩展,将 b 的列添加到 a 的每个对应部分的右侧。

理解三维张量 torch.cat 的关键点:

  1. 维度顺序:对于 3D 张量,通常可以理解为 (深度, 行, 列) 或 (批次, 行, 列)。

  2. dim=0:增加"深度"或"批次"的数量。

  3. dim=1:增加每个"深度"层或"批次"中的行数。

  4. dim=2:增加每行中的元素数量(列数)。

  5. 保持其他维度:除了被连接的维度,其他维度的大小保持不变。

  6. 形状变化:只有指定的 dim 对应的维度大小会改变(增加),其他维度大小保持不变。

  7. 一致性:要连接的张量在非连接维度上的大小必须相同。

3D Matrix Visualization

Let’s visualize the 3D matrices a and b, and their concatenation results.

Matrix a (2x2x2):
Depth 0:    Depth 1:
+---+---+   +---+---+
| 1 | 2 |   | 5 | 6 |
+---+---+   +---+---+
| 3 | 4 |   | 7 | 8 |
+---+---+   +---+---+
Matrix b (2x2x2):
Depth 0:    Depth 1:
+----+----+ +----+----+
| 9  | 10 | | 13 | 14 |
+----+----+ +----+----+
| 11 | 12 | | 15 | 16 |
+----+----+ +----+----+

Concatenation Results:

dim=0 (4x2x2):
Depth 0:    Depth 1:    Depth 2:    Depth 3:
+---+---+   +---+---+   +----+----+ +----+----+
| 1 | 2 |   | 5 | 6 |   | 9  | 10 | | 13 | 14 |
+---+---+   +---+---+   +----+----+ +----+----+
| 3 | 4 |   | 7 | 8 |   | 11 | 12 | | 15 | 16 |
+---+---+   +---+---+   +----+----+ +----+----+
dim=1 (2x4x2):
Depth 0:        Depth 1:
+---+---+       +---+---+
| 1 | 2 |       | 5 | 6 |
+---+---+       +---+---+
| 3 | 4 |       | 7 | 8 |
+---+---+       +---+---+
| 9 | 10 |      | 13| 14|
+---+---+       +---+---+
| 11| 12 |      | 15| 16|
+---+---+       +---+---+
dim=2 (2x2x4):
Depth 0:        Depth 1:
+---+---+---+---+   +---+---+---+---+
| 1 | 2 | 9 | 10|   | 5 | 6 | 13| 14|
+---+---+---+---+   +---+---+---+---+
| 3 | 4 | 11| 12|   | 7 | 8 | 15| 16|
+---+---+---+---+   +---+---+---+---+

当然可以!让我们通过具体的例子来形象地解释不同维度上的拼接。

定义张量

首先,定义三个张量 x, y, z,它们分别具有如下形状:

  • x 的形状是 [2, 1, 3]
  • y 的形状是 [2, 3, 3]
  • z 的形状是 [2, 2, 3]
import torchx = torch.tensor([[[0, 0, 0]], [[0, 0, 0]]])
y = torch.tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]],[[0, 0, 0], [0, 0, 0], [0, 0, 0]]
])
z = torch.tensor([[[0, 0, 0], [0, 0, 0]],[[0, 0, 0], [0, 0, 0]]
])

(1) 在 dim=0 上拼接

dim=0 上拼接,相当于增加“深度”或“批次”的数量。每个张量的“深度”都会堆叠起来。

w_dim0 = torch.cat([x, y, z], dim=0)
print(w_dim0.shape)

形象解释

x:
[[[0, 0, 0]],  # 第一层深度[[0, 0, 0]]   # 第二层深度
]y:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度
]z:
[[[0, 0, 0], [0, 0, 0]],  # 第一层深度[[0, 0, 0], [0, 0, 0]]   # 第二层深度
]拼接结果 w_dim0:
[[[0, 0, 0]],  # x 第一层深度[[0, 0, 0]],  # x 第二层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # y 第一层深度[[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # y 第二层深度[[0, 0, 0], [0, 0, 0]],  # z 第一层深度[[0, 0, 0], [0, 0, 0]]   # z 第二层深度
]

形状:[6, 3, 3]

(2)dim=1 上拼接

dim=1 上拼接,相当于增加每个“深度”层中的行数。每个深度层的行数会拼接起来。

w_dim1 = torch.cat([x, y, z], dim=1)
print(w_dim1.shape)

形象解释

x:
[[[0, 0, 0]],  # 第一层深度的第一行[[0, 0, 0]]   # 第二层深度的第一行
]y:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度的三行[[0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度的三行
]z:
[[[0, 0, 0], [0, 0, 0]],  # 第一层深度的两行[[0, 0, 0], [0, 0, 0]]   # 第二层深度的两行
]拼接结果 w_dim1:
[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],  # 第一层深度的六行[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]   # 第二层深度的六行
]

形状:[2, 6, 3]

当然可以!为了展示如何在 dim=2(第三个维度)上拼接张量,我们需要确保这些张量在前两个维度上的大小是相同的,而在第三个维度上的大小可以不同。

假设我们定义三个张量 a, b, c,它们分别具有如下形状:

  • a 的形状是 [2, 2, 2]
  • b 的形状是 [2, 2, 3]
  • c 的形状是 [2, 2, 1]
import torcha = torch.tensor([[[1, 2], [3, 4]],[[5, 6], [7, 8]]
])b = torch.tensor([[[9, 10, 11], [12, 13, 14]],[[15, 16, 17], [18, 19, 20]]
])c = torch.tensor([[[21], [22]],[[23], [24]]
])

(3)在 dim=2 上拼接

dim=2 上拼接,相当于增加每行中的元素数量(列数)。每个深度层中的列数会拼接起来:

w_dim2 = torch.cat([a, b, c], dim=2)
print(w_dim2)
print(w_dim2.shape)

形象解释

a:
[[[1, 2],      [3, 4]],     # 第一层深度的两行两列[[5, 6],      [7, 8]]      # 第二层深度的两行两列
]b:
[[[9, 10, 11], [12, 13, 14]], # 第一层深度的两行三列[[15, 16, 17], [18, 19, 20]] # 第二层深度的两行三列
]c:
[[[21],        [22]],       # 第一层深度的两行一列[[23],        [24]]        # 第二层深度的两行一列
]拼接结果 w_dim2:
[[[1, 2, 9, 10, 11, 21], [3, 4, 12, 13, 14, 22]],       # 第一层深度的两行六列[[5, 6, 15, 16, 17, 23], [7, 8, 18, 19, 20, 24]]       # 第二层深度的两行六列
]w_dim2 的形状为:[2, 2, 6]

通过在 dim=2 上拼接,结果张量 w_dim2 的第三个维度是各个张量第三个维度的和:2 + 3 + 1 = 6

# 代码输出:
# tensor([[[ 1,  2,  9, 10, 11, 21],
#          [ 3,  4, 12, 13, 14, 22]],
# 
#         [[ 5,  6, 15, 16, 17, 23],
#          [ 7,  8, 18, 19, 20, 24]]])
# 
# 形状: torch.Size([2, 2, 6])

希望这个例子能帮助你更好地理解如何在 dim=2 上拼接张量。
非常好的问题!让我们用书架的比喻来解释这个例子,这将有助于更直观地理解张量的维度。

在这个比喻中:

  • dim=0(第一个维度)代表书架的数量
  • dim=1(第二个维度)代表每个书架的层板数
  • dim=2(第三个维度)代表每个层板可以放置的书本数量(即层板的宽度)

让我们用这个比喻来解释 a, b, 和 c 这三个张量:

  1. 张量 a [2, 2, 2]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放2本书
  2. 张量 b [2, 2, 3]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放3本书
  3. 张量 c [2, 2, 1]:

    • 2个书架
    • 每个书架有2层层板
    • 每个层板可以放1本书

当我们在 dim=2 上拼接这些张量时,相当于我们在不改变书架数量和层板数量的情况下,将每个层板变宽,使其可以容纳更多的书。

拼接后的结果 w_dim2 [2, 2, 6]:

  • 仍然是2个书架(dim=0 没变)
  • 每个书架仍然有2层层板(dim=1 没变)
  • 但是现在每个层板可以放6本书了(dim=2 变成了 2+3+1=6)

形象地说:

原来的书架 a:    原来的书架 b:    原来的书架 c:
[□□]            [□□□]           [□]
[□□]            [□□□]           [□][□□]            [□□□]           [□]
[□□]            [□□□]           [□]拼接后的新书架 w_dim2:
[□□□□□□]  (2+3+1 = 6本书)
[□□□□□□][□□□□□□]
[□□□□□□]

每个 □ 代表一本书(或者说张量中的一个元素)。

这个比喻展示了我们如何在不增加书架数量(dim=0)或层板数量(dim=1)的情况下,通过拼接来增加每个层板可以放置的书本数量(dim=2)。这就是在 dim=2 上进行张量拼接的直观理解。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • cmseasy的两个注入漏洞
  • GiantPandaCV | 大模型训练:Megatron-Kwai中的内存优化
  • Lesson 57 An unusual day
  • git:安装 / 设置环境变量 / 使用
  • 009集——调用方法与递归算法 ——C#学习笔记
  • 网络安全面试题
  • 飞桨Paddle API index_add 详解
  • 8月8号前端日报:web在线进行eps32固件升级
  • 阿里云部署open-webui实现openai代理服务(持续更新)
  • Flink Checkpoint expired before completing解决方法
  • R 语言学习教程,从入门到精通,R 数据框(14)
  • 使用html+css+js实现完整的登录注册页面
  • Python酷库之旅-第三方库Pandas(082)
  • 数据集的简单制作和使用
  • TS中什么是泛型
  • (三)从jvm层面了解线程的启动和停止
  • [iOS]Core Data浅析一 -- 启用Core Data
  • [微信小程序] 使用ES6特性Class后出现编译异常
  • 【跃迁之路】【585天】程序员高效学习方法论探索系列(实验阶段342-2018.09.13)...
  • CODING 缺陷管理功能正式开始公测
  • happypack两次报错的问题
  • javascript面向对象之创建对象
  • Java到底能干嘛?
  • js
  • js递归,无限分级树形折叠菜单
  • miaov-React 最佳入门
  • MySQL-事务管理(基础)
  • Quartz初级教程
  • React16时代,该用什么姿势写 React ?
  • 安卓应用性能调试和优化经验分享
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 从零开始学习部署
  • 大整数乘法-表格法
  • 坑!为什么View.startAnimation不起作用?
  • 前端代码风格自动化系列(二)之Commitlint
  • 数据结构java版之冒泡排序及优化
  • 微信小程序上拉加载:onReachBottom详解+设置触发距离
  • 微信小程序实战练习(仿五洲到家微信版)
  • 终端用户监控:真实用户监控还是模拟监控?
  • ​configparser --- 配置文件解析器​
  • ​io --- 处理流的核心工具​
  • ​插件化DPI在商用WIFI中的价值
  • ‌内网穿透技术‌总结
  • ‌前端列表展示1000条大量数据时,后端通常需要进行一定的处理。‌
  • # Kafka_深入探秘者(2):kafka 生产者
  • (done) 两个矩阵 “相似” 是什么意思?
  • (Git) gitignore基础使用
  • (Matalb回归预测)PSO-BP粒子群算法优化BP神经网络的多维回归预测
  • (亲测有效)解决windows11无法使用1500000波特率的问题
  • (三分钟)速览传统边缘检测算子
  • (十)c52学习之旅-定时器实验
  • (详细文档!)javaswing图书管理系统+mysql数据库
  • (一)Kafka 安全之使用 SASL 进行身份验证 —— JAAS 配置、SASL 配置
  • (转)linux自定义开机启动服务和chkconfig使用方法
  • (转)总结使用Unity 3D优化游戏运行性能的经验