当前位置: 首页 > news >正文

形象化理解pytorch中的tensor.scatter操作

定义

        scatter_(dim, index, src, *, reduce=None) -> Tensor

pytorch官网说这个函数的作用是从src中把index指定的位置把数据写入到self里面,然后给了一个公式:           

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

这个公式我也是一脸懵,但是我们可以把他降维到二维表格上,即:

            self[ index[i][j] ][j] = src[i][j]  # if dim == 0

把src从 i 行 移动到了 index[i][j] 行

            self[i][ index[i][j] ] = src[i][j]  # if dim == 1

把src 从 j 列移动到了 index[i][j] 列

对此,个人认为比较直观的理解:
        dim=0,就是把本行这个data放到本列的哪行(上下移动)
        dim=1,就是把本列这个data放到本行的哪列(左右移动)

所以,index数组其实是一个位置变化的映射表

例子1

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

当我们指定 dim=0,就是把每一行data放到上下移动位置,比如我们给一个例子

>>> index = torch.tensor([ [0, 0, 0], [1, 1, 1], [2, 2, 0] ]) 
>>> src.scatter(dim = 0, index=index, src = src)
tensor([[1, 2, 9],[4, 5, 6],[7, 8, 9]])

可以看到,scatter之后只有 src[0][2] 发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=0 时候是把src从 i 行 移动到了 index[i][j] 行(上下移动), 这里的index表 0行所有的元素都移动到了0行对应位置, 1行所有的的元素都移动到了1行对应位置, 只有2行最后一个元素移动到了0行,造成的结果就是src只有最后一个元素移动到了0行的对应位置(从src[2][2]移动到了src[0][2])

 例子2

下面我们再试试dim = 1 时候 把src 从 j 列移动到了 index[i][j] 列

给定src是一个顺序数组,我们可以更清楚看到这一变化过程。

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

index给定如下

>>> index = torch.tensor([ [0, 1, 2], [0, 1, 2], [0, 1, 0] ])
>>> index
tensor([[0, 1, 2],[0, 1, 2],[0, 1, 0]])>>> src.scatter(dim = 1, index=index, src = src) 
tensor([[1, 2, 3],[4, 5, 6],[9, 8, 9]])

可以看到,这里src也只有一个位置发生了变化,为什么呢?

 前面提到了index数组其实是一个位置变化的映射表,  dim=1 时候是把src从 i 列 移动到了 index[i][j] 列 (左右移动), 这里的index表 0行 012 列对应的元素都移动到了0行 012列 对应位置(相当于没动), 1行 012 列对应的元素都移动到了1行 012列 对应位置(相当于没动), 只有2行最后一个元素移动到了0列,造成的结果就是src只有最后一个元素移动到了2行0列的位置(从src[2][2]移动到了src[2][0] )

意义

那么这种映射这么复杂,它的意义在哪里呢? 

答:一般scatter用于生成onehot向量

这里还是举个例子

我们还是拿之前的src数组

>>> src = torch.tensor([ [1,2,3], [4,5,6], [7,8,9] ] )
>>> src
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

我们要如何理解它呢?我们可以认为它是三只股票在昨天、今天、明天的股票价格,昨天三只股票的价格分别为1,4,7,今天三只股票的价格分别为2,5,8, 明天三只股票的价格分别为3,6,9。

现在我们要训练一个预测后天股票价格的神经网络,我们给模型的输入应该是昨天三只股票的价格、今天三只股票的价格、明天三只股票的价格,即1,4,7,2,5,8,3,6,9。同时,我们要把每个数字转化为一个onehot的向量,这样的结果是我们期望的。

所以,我们要做的事情是把src转换为一个 3*3 的矩阵,矩阵中每个元素是一个能表示0-9的10维one-hot向量。

拿一段常用的onehot生成代码说事。


def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

先不谈代码含义,输出结果如下 

>>> to_onehot(src, 10) 
[tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]),tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]), tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])]

这个结果基本符合了我们的期望,那么这个是如何做到的呢? 


# X shape: batch_size, prices_list
def to_onehot(X, n_class):# 返回结果 shape: prices_list, batch_size, onehot_size 三维return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

首先,src按照昨天,今天,明天的维度,被切分为了三个列向量 [1,4,7]、[2,5,8]、 [3,6,9] 。这三个列向量对应了我们的输出,one_hot给定一个列向量,可以转换为一个one-hot列向量组。

def one_hot(x, n_class, dtype=torch.float32):# X shape: (batch, 1), output shape: (batch, n_class)x = x.long()res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)res.scatter_(1, x.view(-1, 1), 1)return res

为了简单,我们举一个例子


>>> one_hot(torch.tensor([1,2,3]), 4) 
tensor([[0., 1., 0., 0.],[0., 0., 1., 0.],[0., 0., 0., 1.]])>>> torch.tensor([1,2,3]).view(-1,1)  
tensor([[1],[2],[3]])

 可以看到,res是一个全0矩阵,scatter操作在dim=1时,是一个左右移动的位置映射表,这里的res是一个 3 * 4 的矩阵,src是一个数字,可以认为是跟res同样大小的全1矩阵,但是index是一个 3*1 的矩阵,也就是这个位置映射表可以认为是一个3行1列的映射表,即 全1矩阵的0 行 0 列映射到res的 0 行 1列,全1矩阵的1行0列映射到res的1行2列,全1矩阵的2行0列映射到res的2行3列,其他保持不变(其他都是0),dim=1这种操作就是制造了one-hot向量

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • VsCode 内置 Git 可视化操作【初始化仓库】
  • HarmonyOS NEXT 底部选项卡功能
  • Excel排序错误原因之一
  • Spring cache的使用
  • 锐捷交换机常用命令
  • 【阿里千问最新多模态模型】Qwen2-VL:让世界看得更清晰
  • iText2KG:显著降低LLM构建知识图谱时的幻觉现象
  • React第五章(swc)
  • npm i:【idealTree:vue.web: sill idealTree buildDeps vue3拉取jQuery卡慢失败:[解决]】
  • LabVIEW如何确保采集卡稳定运行
  • Linux 磁盘扩容操作指引
  • 深入CSS 布局——WEB开发系列29
  • 【JavaSE】重写equals()和hashCode()
  • 2409wtl,网浏包装
  • 程序员如何写笔记?
  • 《用数据讲故事》作者Cole N. Knaflic:消除一切无效的图表
  • 【MySQL经典案例分析】 Waiting for table metadata lock
  • 【刷算法】从上往下打印二叉树
  • css选择器
  • Docker 1.12实践:Docker Service、Stack与分布式应用捆绑包
  • Eureka 2.0 开源流产,真的对你影响很大吗?
  • HTTP中的ETag在移动客户端的应用
  • JavaScript 事件——“事件类型”中“HTML5事件”的注意要点
  • LeetCode29.两数相除 JavaScript
  • puppeteer stop redirect 的正确姿势及 net::ERR_FAILED 的解决
  • springMvc学习笔记(2)
  • vue从创建到完整的饿了么(11)组件的使用(svg图标及watch的简单使用)
  • 阿里云应用高可用服务公测发布
  • 聚类分析——Kmeans
  • 聊聊springcloud的EurekaClientAutoConfiguration
  • 前端存储 - localStorage
  • 浅谈web中前端模板引擎的使用
  • 设计模式 开闭原则
  • 收藏好这篇,别再只说“数据劫持”了
  • postgresql行列转换函数
  • ### RabbitMQ五种工作模式:
  • (1)(1.11) SiK Radio v2(一)
  • (1)(1.19) TeraRanger One/EVO测距仪
  • (11)MATLAB PCA+SVM 人脸识别
  • (5)STL算法之复制
  • (6)设计一个TimeMap
  • (9)目标检测_SSD的原理
  • (Java实习生)每日10道面试题打卡——JavaWeb篇
  • (Python) SOAP Web Service (HTTP POST)
  • (第27天)Oracle 数据泵转换分区表
  • (二) 初入MySQL 【数据库管理】
  • (分类)KNN算法- 参数调优
  • (附源码)ssm本科教学合格评估管理系统 毕业设计 180916
  • (六) ES6 新特性 —— 迭代器(iterator)
  • (七)Java对象在Hibernate持久化层的状态
  • (删)Java线程同步实现一:synchronzied和wait()/notify()
  • (十二)springboot实战——SSE服务推送事件案例实现
  • .NET Core 版本不支持的问题
  • .Net Remoting(分离服务程序实现) - Part.3
  • .NET/C# 如何获取当前进程的 CPU 和内存占用?如何获取全局 CPU 和内存占用?