当前位置: 首页 > news >正文

tensor连接和拆分

文章目录

    • 连接
      • torch.cat()
        • 案例准备
      • torch.stack()
        • 区别
    • 拆分
      • torch.split()

连接

torch.cat()

函数目的: 在给定维度上对输入的张量序列 进行连接操作。

案例准备
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
b = torch.tensor([[10,10,10,],[10,10,10],[10,10,10,]], dtype=torch.float)

在这里插入图片描述

# dim指的是维度,dim = 0就是行,所以下面的代码就是按行拼接
print("按行拼接:\n",torch.cat((a,b),dim=0))
print("按行拼接:\n",torch.cat((a,b),dim=0).shape) #6行3列

在这里插入图片描述

print("按列拼接:\n",torch.cat((a,b),dim=1))
print("按列拼接:\n",torch.cat((a,b),dim=1).shape)#3行6列

在这里插入图片描述

torch.stack()

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
也就是2维拼成3维,3维拼4维,以此类推。

print("按行拼接:\n",torch.stack((a,b),dim=0))
print("按行拼接:\n",torch.stack((a,b),dim=0).shape) 

在这里插入图片描述

print("按行拼接:\n",torch.stack((a,b),dim=1))
print("按行拼接:\n",torch.stack((a,b),dim=1).shape)

在这里插入图片描述

print("按行拼接:\n",torch.stack((a,b),dim=2))
print("按行拼接:\n",torch.stack((a,b),dim=2).shape)

在这里插入图片描述

区别

stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。

c = torch.tensor([[10,20],[30,40],[50,60]], dtype=torch.float)
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
torch.cat((a,c),dim=1)

在这里插入图片描述

#但是以下情况就会出错
torch.cat((a,c),dim=0)

在这里插入图片描述
如图,按行拼接会缺数据,报错吗,应该的。
在这里插入图片描述

torch.stack((a,c),dim=0)
###运行结果
RuntimeError: stack expects each tensor to be equal size, but got [3, 3] at entry 0 and [3, 2] at entry 1再次验证stack需要两个大小一样的张量

拆分

torch.split()

def split(
tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, …]:

  • 按块大小拆分张量 除不尽的取余数,返回一个元组
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float)
print(torch.split(a,2,dim=0))	#按行拆,两行拆成一个
print(torch.split(a,1,dim=0))	#按行拆,一行拆成一个
print(torch.split(a,1,dim=1))	#按列拆,一列拆成一个
print(torch.split(a,2,dim=1)) 	#按列拆,两列拆成一个

在这里插入图片描述

  • 按块数拆分张量
torch.chunk(a,2,dim=0)	#按行拆成两块
torch.split(a,2,dim=1)	#按列拆成两块

在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 搜维尔科技:ART光学空间定位虚拟交互工业级光学跟踪系统
  • sourcetree配置ssh连接gitee
  • 中国企业500强!最新名单揭晓→
  • JavaScript高级进阶(二)
  • IGNAV_NHC分析
  • 【深度学习】训练过程中一个OOM的问题,太难查了
  • 多人开发小程序设置体验版的痛点
  • 视频推拉流/直播点播EasyDSS平台安装失败并报错“install mediaserver error”是什么原因?
  • Centos7.9部署Gitlab-ce-16.9
  • 【人工智能学习笔记】3_2 机器学习基础之机器学习经典算法介绍
  • 程序员如何写笔记并整理资料?
  • react js 路由 Router
  • 跑步戴的耳机哪个品牌的好?全新测评推荐五大爆款骨传导运动耳机
  • 工业一体机帮助MES系统打通工厂数据采集及目视化
  • Python 如何类与对象
  • echarts的各种常用效果展示
  • FastReport在线报表设计器工作原理
  • gitlab-ci配置详解(一)
  • Java基本数据类型之Number
  • Linux CTF 逆向入门
  • Netty 4.1 源代码学习:线程模型
  • PAT A1092
  • ubuntu 下nginx安装 并支持https协议
  • vue2.0一起在懵逼的海洋里越陷越深(四)
  • vue脚手架vue-cli
  • 阿里云Kubernetes容器服务上体验Knative
  • 基于MaxCompute打造轻盈的人人车移动端数据平台
  • 精益 React 学习指南 (Lean React)- 1.5 React 与 DOM
  • 理解 C# 泛型接口中的协变与逆变(抗变)
  • 深度解析利用ES6进行Promise封装总结
  • 数据结构java版之冒泡排序及优化
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 问:在指定的JSON数据中(最外层是数组)根据指定条件拿到匹配到的结果
  • 应用生命周期终极 DevOps 工具包
  • 7行Python代码的人脸识别
  • raise 与 raise ... from 的区别
  • 机器人开始自主学习,是人类福祉,还是定时炸弹? ...
  • 通过调用文摘列表API获取文摘
  • ​configparser --- 配置文件解析器​
  • ​数据链路层——流量控制可靠传输机制 ​
  • ​无人机石油管道巡检方案新亮点:灵活准确又高效
  • # 利刃出鞘_Tomcat 核心原理解析(二)
  • #我与Java虚拟机的故事#连载01:人在JVM,身不由己
  • #知识分享#笔记#学习方法
  • ( )的作用是将计算机中的信息传送给用户,计算机应用基础 吉大15春学期《计算机应用基础》在线作业二及答案...
  • (2020)Java后端开发----(面试题和笔试题)
  • (3)nginx 配置(nginx.conf)
  • (35)远程识别(又称无人机识别)(二)
  • (4)STL算法之比较
  • (C语言)strcpy与strcpy详解,与模拟实现
  • (pojstep1.3.1)1017(构造法模拟)
  • (zhuan) 一些RL的文献(及笔记)
  • (二)测试工具
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (七)Appdesigner-初步入门及常用组件的使用方法说明