当前位置: 首页 > news >正文

【Pytorch】cumsum的实现逻辑

本文只记录cumsum的实现逻辑的CUDA部分,也即底层调用了CUDA的什么实现算子。

void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, ScalarType::BFloat16,self.scalar_type(), "cumsum_cuda",[&]() {scalar_t init = 0;scan_dim<scalar_t>(self,result,dim,init,std::plus<scalar_t>());});
}

通过定位源码,找到了执行kernel的关键代码,可以看到,此代码内部调用了Pytorch定义的宏,核心调用是pytorch定义的名为scan_dim的模板函数。
该模板函数的定义位于:aten/src/ATen/native/cuda/ScanUtils.cuh
代码如下:

template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,int64_t dim, scalar_t init, BinaryFunction binary_op) {int ndim = self.dim();auto self_ = self.expect_contiguous();TORCH_INTERNAL_ASSERT(result.is_contiguous());if (self.numel() == self.size(dim)) {cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());} else if (dim == ndim - 1) {scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);} else {scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);}
}

该函数内部最重要的是后面的条件结构,首先如果元素的总数和当前维度的元素个数相同,也即tensor是一维的,直接利用cub的前缀扫描方法,如果元素的总数和当前维度的元素个数不同,又分为最内层的维度,也即最后一维,以及其他情况。

template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,int dim, scalar_t init, BinaryFunction binary_op) {const int64_t row_size = self.size(dim);auto sizes = self.sizes();// Treat all outer dimensions (i.e. dim_ < dim) as one.const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);// Treat all inner dimensions (i.e. dim > dimension) as one.const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());dim3 threads(std::min(512, int(num_irows)));int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));check_fits_in_unsigned(num_irows, "num_irows");check_fits_in_unsigned(num_orows, "num_orows");check_fits_in_unsigned(row_size, "row_size");tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),num_orows, num_irows, row_size, init, binary_op);C10_CUDA_KERNEL_LAUNCH_CHECK();
}template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,scalar_t init, BinaryFunction binary_op) {int64_t ndim = self.dim();// Treat all outer dimensions as a single dimension.int64_t row_size = self.size(ndim - 1);int64_t num_rows = self.numel() / row_size;// assuming max_num_threads per block is 512const uint32_t num_threads = 512;const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);const uint32_t num_threads_x = (1 << log_num_threads_x);const uint32_t num_threads_y = num_threads / num_threads_x;dim3 threads(num_threads_x, num_threads_y);int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");check_fits_in_unsigned(row_size, "row_size");tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),num_rows, row_size, log_num_threads_x, init, binary_op);C10_CUDA_KERNEL_LAUNCH_CHECK();
}

可以看到Pytorch针对上述两种情况进行了自定义,因为cub的inclusive_scan针对的是一维张量而非多维张量。

在调用核函数前,首先要定义调用核函数的网络结构和线程块结构,pytorch默认的线程块大小是512的,那么如何将512个线程块进行二维切分以满足合适的比例呢,pytorch中的做法是像下面这样:

template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {integer log_num_threads_x = 0;integer log_num_threads_y = 0;while (((integer)1 << log_num_threads_x) < row_size) {++log_num_threads_x;}while (((integer)1 << log_num_threads_y) < num_rows) {++log_num_threads_y;}// we want to keep the ratio between the x-threads and y-threads about the same as// the ratio between the row_size and num_rows, but the total number of threads in// a block should be about 512integer diff = log_num_threads_x - log_num_threads_y;// 9 is from log2(512)log_num_threads_x = ((integer)9 + diff) / (integer)2;// I found that in having larger log_num_threads_x can give significant speed up in some cases,// but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it// similar to the previous implementation// Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);return log_num_threads_x;
}

使用对数进行计算是便于计算出的x的结果可以整除,关键点在于最后平衡二者的比例的那行代码。可以预见,在某些情况下由于待处理数据的大小超过512造成线程块不能够完全分配的情况,此时就需要顾及线程块的比例,那么如果两个维度上线程块的对数值分别为x和y,对应的线程数分别为X,Y,也即 X = 2 x X=2^x X=2x。此时X与Y的比例 X / Y X / Y X/Y 的结果也即 2 x − y 2^{x - y} 2xy ,其实也就是 2 d i f f 2 ^ {diff} 2diff。那么如果将x变为(diff+9) / 2, y也就是 (9 - diff) / 2,二者相减也就是diff,因此保证了变换前后的比例。

相关文章:

  • Linux网络:传输层协议TCP(一)
  • 基于riscv64架构的Dayu800开发板的napi_demo开发介绍
  • MySQL大框架总结
  • 《南京师大学报》(社会科学版)
  • 如何进行小程序的调试
  • c++基础2
  • 在WPF中使用WebView2详解
  • Angular 18.2.0 的新功能增强和创新
  • 问题记录-SpringBoot 2.7.2 整合 Swagger 报错
  • html必知必会-html内嵌JavaScript和文件路径
  • 如何使用大语言模型绘制专业图表
  • Sqlmap中文使用手册 - Techniques模块参数使用
  • 最新源支付系统源码 V7版全开源 免授权 附搭建教程
  • C++ | Leetcode C++题解之第278题第一个错误的版本
  • Vue2和Vue3实战代码中的小差异(实时更新)
  • 【笔记】你不知道的JS读书笔记——Promise
  • 【从零开始安装kubernetes-1.7.3】2.flannel、docker以及Harbor的配置以及作用
  • 03Go 类型总结
  • bearychat的java client
  • FastReport在线报表设计器工作原理
  • flask接收请求并推入栈
  • gops —— Go 程序诊断分析工具
  • HTTP中GET与POST的区别 99%的错误认识
  • Javascript编码规范
  • Linux CTF 逆向入门
  • magento2项目上线注意事项
  • puppeteer stop redirect 的正确姿势及 net::ERR_FAILED 的解决
  • text-decoration与color属性
  • 从零开始的无人驾驶 1
  • 判断客户端类型,Android,iOS,PC
  • 前端js -- this指向总结。
  • 世界上最简单的无等待算法(getAndIncrement)
  • 延迟脚本的方式
  • 一些关于Rust在2019年的思考
  • 译有关态射的一切
  • 容器镜像
  • #HarmonyOS:基础语法
  • #if和#ifdef区别
  • #我与Java虚拟机的故事#连载01:人在JVM,身不由己
  • (PADS学习)第二章:原理图绘制 第一部分
  • (python)数据结构---字典
  • (补)B+树一些思想
  • (附源码)ssm失物招领系统 毕业设计 182317
  • (附源码)基于ssm的模具配件账单管理系统 毕业设计 081848
  • (规划)24届春招和25届暑假实习路线准备规划
  • (学习日记)2024.03.25:UCOSIII第二十二节:系统启动流程详解
  • (一) 初入MySQL 【认识和部署】
  • .mysql secret在哪_MySQL如何使用索引
  • .net core 的缓存方案
  • .NET Core 将实体类转换为 SQL(ORM 映射)
  • .NET Framework 服务实现监控可观测性最佳实践
  • .NET 设计模式初探
  • .Net 知识杂记
  • .Net下的签名与混淆
  • .net中我喜欢的两种验证码