当前位置: 首页 > news >正文

飞桨Paddle API index_add 详解

index_add¶

paddle.index_add(xindexaxisvaluename=None)[源代码]¶

沿着指定轴 axis 将 index 中指定位置的 x 与 value 相加,并写入到结果 Tensor 中的对应位置。这里 index 是一个 1-D Tensor。除 axis 轴外,返回的 Tensor 其余维度大小和输入 x 相等, axis 维度的大小等于 index 的大小。

官方文档:index_add-API文档-PaddlePaddle深度学习平台

我们还是通过一个代码示例来学习:

x = paddle.ones([5, 3])
value = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.float32)
index = paddle.to_tensor([0, 4, 2])
print(x)x = paddle.index_add(x, index, 0, value)
print(x)

 输出

Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
Tensor(shape=[5, 3], dtype=float32, place=Place(cpu), stop_gradient=True,[[2. , 3. , 4. ],[1. , 1. , 1. ],[8. , 9. , 10.],[1. , 1. , 1. ],[5. , 6. , 7. ]])

API 解析:index_add

查看前面的例子输出,可以看到,index_add就是把value的各个值,按照index里的值为索引,加入到源x里面去,比如

value = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.float32)
index = paddle.to_tensor([0, 4, 2])

首先取出value[0] ,发现index[0]是 0,那么就把value[0] 跟x[0]相加

取出value[1] ,发现index[1] 是4,那么就把value[1] 跟x[4]相加

取出value[2] ,发现index[2] 是2,那么就把value[2] 跟x[2]相加

在飞桨官方没有index_add函数的时候,可以用python来实现,当然速度会慢很多:

def paddleindex_add(x, dim, index, source): # 飞桨的index_add'''
x = paddle.ones([5, 3])
t = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=paddle.float32)
index = paddle.to_tensor([0, 4, 2])
# print(x)
with Benchmark("paddleindex_add"):x = paddleindex_add(x, 0, index, t)
print(x)'''for i in range(len(index)):x[index[i]] += source[i]return x

可以从赋值语句看到,就是从index里面取出值,然后x和source的相关值相加:x[index[i]] += source[i]

当然注释里面用了Benchmark函数,抄李沐老师的,源码如下

import time
class Timer:  #@save"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)def cumsum(self):"""返回累计时间"""return np.array(self.times).cumsum().tolist()class Benchmark:"""用于测量运行时间"""def __init__(self, description='Done'):self.description = descriptiondef __enter__(self):self.timer = Timer()return selfdef __exit__(self, *args):print(f'{self.description}: {self.timer.stop():.4f} sec')

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 8月8号前端日报:web在线进行eps32固件升级
  • 阿里云部署open-webui实现openai代理服务(持续更新)
  • Flink Checkpoint expired before completing解决方法
  • R 语言学习教程,从入门到精通,R 数据框(14)
  • 使用html+css+js实现完整的登录注册页面
  • Python酷库之旅-第三方库Pandas(082)
  • 数据集的简单制作和使用
  • TS中什么是泛型
  • MySQL与PostgreSQL语法区别
  • 小山菌_代码随想录算法训练营第六十二天|dijkstra(堆优化版)精讲 、Bellman_ford 算法精讲
  • 重新连接 到 时出错 Microsoft Windows Network:本地设备名已在使用中
  • Qt:线程
  • LeetCode 热题100-2
  • Unity引擎加密方案解析
  • Linux装ifort环境
  • css属性的继承、初识值、计算值、当前值、应用值
  • Javascript编码规范
  • js
  • js正则,这点儿就够用了
  • Logstash 参考指南(目录)
  • node 版本过低
  • v-if和v-for连用出现的问题
  • vue数据传递--我有特殊的实现技巧
  • Yeoman_Bower_Grunt
  • 前端每日实战:70# 视频演示如何用纯 CSS 创作一只徘徊的果冻怪兽
  • 设计模式 开闭原则
  • 使用 Docker 部署 Spring Boot项目
  • 一个SAP顾问在美国的这些年
  • MPAndroidChart 教程:Y轴 YAxis
  • 策略 : 一文教你成为人工智能(AI)领域专家
  • ​Benvista PhotoZoom Pro 9.0.4新功能介绍
  • ‌JavaScript 数据类型转换
  • #C++ 智能指针 std::unique_ptr 、std::shared_ptr 和 std::weak_ptr
  • $(selector).each()和$.each()的区别
  • $var=htmlencode(“‘);alert(‘2“); 的个人理解
  • (3)(3.2) MAVLink2数据包签名(安全)
  • (6)添加vue-cookie
  • (待修改)PyG安装步骤
  • (二)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (过滤器)Filter和(监听器)listener
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (免费领源码)python#django#mysql公交线路查询系统85021- 计算机毕业设计项目选题推荐
  • (四)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (贪心 + 双指针) LeetCode 455. 分发饼干
  • (一)SvelteKit教程:hello world
  • (转)JAVA中的堆栈
  • .Family_物联网
  • .L0CK3D来袭:如何保护您的数据免受致命攻击
  • .net wcf memory gates checking failed
  • .net/c# memcached 获取所有缓存键(keys)
  • /etc/fstab和/etc/mtab的区别
  • @AliasFor 使用
  • @Builder用法
  • @RequestBody与@RequestParam:Spring MVC中的参数接收差异解析
  • []新浪博客如何插入代码(其他博客应该也可以)