当前位置: 首页 > news >正文

tensorflow算子调用示例(MINIST)

tensorflow算子调用示例(MINIST)

本文以MINIST为例,阐述在模型训练时,tensorflow框架每个算子具体调用kernel的过程。

1. 数据准备和输入

在 MNIST 示例中,首先加载数据并进行预处理,生成用于训练和测试的数据集。这个步骤本身不涉及 GPU 加速,但数据会被加载到内存中,准备在计算图中进行后续操作。

import tensorflow as tf
from tensorflow.keras.datasets import mnist# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0

2. 构建模型

在 TensorFlow 中,神经网络的构建涉及多个算子,如矩阵乘法(MatMul)、卷积(Conv2D)、激活函数(如 ReLU)、以及用于分类任务的 Softmax

model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10)
])
1.Flatten

这是一个张量操作(Tensor Reshape),将 28x28 的图像展平为一维数组。

2.Dense (全连接层)

这里主要使用了矩阵乘法算子 MatMul 和偏置项加法 BiasAdd,这些算子会被发送到 CUDA 设备上进行计算。如果使用 GPU,这些操作会被 TensorFlow 映射到相应的 CUDA 核函数。

在 TensorFlow 中,Dense 层的操作主要依赖于矩阵乘法和偏置项的加法。具体的核函数与这些操作相关。以下是每个操作对应的 CUDA 核函数的细节:
2.1.矩阵乘法 (MatMul) 的核函数:
对于 Dense 层中的矩阵乘法操作,TensorFlow 在 GPU 上通过 NVIDIA 提供的 cuBLAS 库执行该操作。具体的核函数根据数据类型的不同,通常是以下两个:

  • cublasSgemm(用于单精度浮点数,float
  • cublasDgemm(用于双精度浮点数,double

cublasSgemm 函数原型:

cublasStatus_t cublasSgemm(cublasHandle_t handle,cublasOperation_t transa,cublasOperation_t transb,int m,int n,int k,const float *alpha,  // 标量 alphaconst float *A,      // 输入矩阵 Aint lda,             // leading dimension of Aconst float *B,      // 输入矩阵 Bint ldb,             // leading dimension of Bconst float *beta,   // 标量 betafloat *C,            // 输出矩阵 Cint ldc              // leading dimension of C
);

cublasDgemm 函数原型:

cublasStatus_t cublasDgemm(cublasHandle_t handle,cublasOperation_t transa,cublasOperation_t transb,int m,int n,int k,const double *alpha,  // 标量 alphaconst double *A,      // 输入矩阵 Aint lda,              // leading dimension of Aconst double *B,      // 输入矩阵 Bint ldb,              // leading dimension of Bconst double *beta,   // 标量 betadouble *C,            // 输出矩阵 Cint ldc               // leading dimension of C
);

2.2偏置项加法 (BiasAdd) 的核函数:
对于 BiasAdd 操作,TensorFlow 在 GPU 上使用的是自定义的 CUDA 核函数。这些核函数负责将偏置项加到矩阵的每一行或列上,通常涉及到张量广播操作。

BiasAdd 核函数原型:

在 TensorFlow 的源代码中,BiasAdd 通常是通过名为 BiasAddKernel 的自定义 CUDA 核来实现的。在反向传播过程中,偏置项的梯度计算会调用 BiasGradKernel

由于这是自定义实现的 CUDA 核函数,源代码可以在 TensorFlow 的 GitHub 仓库中找到。以下是一个简化的自定义核函数实现示例,用于在 GPU 上执行 BiasAdd 操作:

template <typename T>
__global__ void BiasAddKernel(const T* input, const T* bias, T* output, int num_rows, int num_cols) {int row = blockIdx.x * blockDim.x + threadIdx.x;if (row < num_rows) {for (int col = 0; col < num_cols; ++col) {output[row * num_cols + col] = input[row * num_cols + col] + bias[col];}}
}

该核函数的主要功能是遍历输入矩阵的每一行,并将偏置项加到每个元素上。每个线程处理一行矩阵中的元素,通过并行化来提升计算效率。

2.3反向传播的核函数

在反向传播中,计算权重和偏置项的梯度同样需要矩阵运算和累加操作。

  • 权重梯度的计算 依然会使用 cuBLAS 中的 cublasSgemmcublasDgemm,这是通过反向传播的梯度和输入数据的矩阵乘法来计算权重的更新。
  • 偏置项梯度的计算 通常是一个简单的张量求和操作,这可以通过自定义的 CUDA 核函数来高效实现,典型实现如下:
template <typename T>
__global__ void BiasGradKernel(const T* grad_output, T* bias_grad, int num_rows, int num_cols) {int col = blockIdx.x * blockDim.x + threadIdx.x;if (col < num_cols) {T sum = 0;for (int row = 0; row < num_rows; ++row) {sum += grad_output[row * num_cols + col];}bias_grad[col] = sum;}
}

该核函数遍历反向传播的梯度输出,并在每一列上累加梯度,得到偏置项的梯度。

2.4CUDA 流管理

为了管理异步操作,TensorFlow 会使用 CUDA 流(stream)来并行执行多个核函数。这允许 MatMulBiasAdd 和反向传播操作同时进行,而不会阻塞 CPU。

CUDA 流管理函数:
  • cudaStreamCreate:创建一个 CUDA 流,允许异步操作。
  • cudaStreamSynchronize:等待流中的所有操作完成。
  • cudaLaunchKernel:将核函数提交到 CUDA 流中,执行并行操作。
3.Dropout

虽然 Dropout 是一种正则化技术,但其操作(如随机丢弃部分神经元)也会利用 CUDA 进行高效的矩阵运算。

  • ReLU (激活函数):激活函数使用的是 tf.nn.relu 算子。这个算子将被映射到 CUDA 核来加速 ReLU 操作,特别是在大规模的矩阵计算中。
__global__ void DropoutKernel(float* input, float* output, float* mask, float keep_prob, int size) {int idx = blockIdx.x * blockDim.x + threadIdx.x;if (idx < size) {// mask 生成,利用 GPU 生成随机数if (mask[idx] < keep_prob) {// 保留神经元并缩放输出值output[idx] = input[idx] / keep_prob;} else {// 丢弃神经元output[idx] = 0.0f;}}
}

input:输入张量,是神经网络前一层的输出。
output:经过 Dropout 处理后的输出张量。
mask:随机生成的掩码矩阵,包含 0 和 1。
keep_prob:神经元保留的概率。
size:神经元的数量(即输入张量的大小)。

3. 前向传播与反向传播

在训练过程中,TensorFlow 中的计算图会负责前向传播(计算损失)和反向传播(计算梯度并更新权重)。这涉及大量的矩阵运算、卷积操作以及求导计算。

矩阵乘法 (MatMul) 为例,当在训练中进行矩阵计算时,TensorFlow 会将计算任务调度到 GPU 上。对于每个算子,TensorFlow 调用相应的 CUDA 核函数,利用 GPU 的并行计算能力进行加速。例如,tf.matmul 在底层会映射到对应的 CUDA 核函数,用以处理大规模的矩阵运算。


loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

这里的 SparseCategoricalCrossentropy 损失函数涉及Softmax 操作和交叉熵的计算,也会映射到 CUDA 核进行加速。
softmax核函数示例:

__global__ void SoftmaxKernel(float* logits, float* output, int n) {int idx = blockIdx.x * blockDim.x + threadIdx.x;if (idx < n) {float max_logit = -FLT_MAX;float sum = 0.0f;// 找到最大值,避免指数溢出for (int i = 0; i < n; i++) {max_logit = max(max_logit, logits[i]);}// 计算 Softmaxfor (int i = 0; i < n; i++) {sum += expf(logits[i] - max_logit);}for (int i = 0; i < n; i++) {output[i] = expf(logits[i] - max_logit) / sum;}}
}

交叉熵核函数示例:

__global__ void CrossEntropyKernel(float* softmax_output, int* labels, float* loss, int n) {int idx = blockIdx.x * blockDim.x + threadIdx.x;if (idx < n) {int true_label = labels[idx];float p_true = softmax_output[true_label];loss[idx] = -logf(p_true);}
}

加速过程

  • Softmax 加速:通过 CUDA 内核并行化计算每个 logits 的指数值和总和,并高效实现归约操作。
  • 交叉熵加速:使用并行化计算 Softmax 输出中对应真实标签的概率,并计算其对数作为交叉熵损失。
  • 整体加速:TensorFlow 中 SparseCategoricalCrossentropy 损失函数的实现会将 Softmax 和交叉熵的计算结合在一起,通过 CUDA 内核融合和高效的内存访问来加速整个过程。

4. 训练优化

TensorFlow 的优化器(例如 SGD、Adam 等)会使用反向传播算法来更新模型的参数。这一过程涉及对权重的矩阵求导和更新,这些操作会通过 CUDA 的矩阵运算加速库(例如 cuBLAS、cuDNN)来实现。

model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

model.fit 的过程中,TensorFlow 会自动将支持 CUDA 的算子交给 GPU 执行。例如:

  • cuBLAS:用于加速矩阵乘法、矩阵求逆等基本线性代数运算。
  • cuDNN:用于加速卷积操作,特别是涉及到卷积神经网络的训练时,tf.nn.conv2d 操作会调用 cuDNN 的加速函数来计算。

5. 设备分配与张量操作

在执行时,TensorFlow 会自动将计算图中的操作分配到可用的设备上。如果检测到有 GPU 可用,TensorFlow 会将支持 CUDA 的算子分配到 GPU 上。

# 检查 TensorFlow 是否使用 GPU
print("GPU Available: ", tf.config.list_physical_devices('GPU'))

TensorFlow 使用其底层的 Placer 机制来自动选择设备,并将计算任务调度到 GPU。如果使用 CUDA,它会调用对应的 CUDA 内核来执行张量操作,如卷积、矩阵乘法和激活函数。

6. 推理时的 CUDA 使用

在训练完成后,推理过程也会涉及到类似的操作,TensorFlow 会继续利用 CUDA 来加速前向传播中的计算。


Reference:

  1. https://www.tensorflow.org/guide/gpu

  2. https://docs.nvidia.com/cuda/

  3. https://www.tensorflow.org/install/gpu

  4. https://docs.nvidia.com/cuda/cublas/index.html

  5. https://developer.nvidia.com/cudnn

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【项目实战】如何在项目中基于 Spring Boot Starter 开发简单的 SDK
  • ARM基础知识点及简单汇编语法
  • 【数据结构-栈】力扣71. 简化路径
  • 【计算机网络 - 基础问题】每日 3 题(二十一)
  • YOLOv8 OBB win10+ visual 2022移植部署
  • 【2023次方 / B】
  • 王红梅老师ppt介绍算法设计一般过程---对上周csdn的补充----可以参考老版教师用书--单链表专题在介绍插入时介绍了正向思维方法,这是更详细的解释跟全面
  • iptables和nftables
  • 淘客系统开发之卷轴模式系统源码功能分析
  • 解锁视频生成新时代! 探索智谱CogVideoX-2b:轻松生成6秒视频的详细指南
  • ReKep——李飞飞团队提出的让机器人具备空间智能:基于视觉语言模型GPT-4o和关系关键点约束
  • C语言常见字符串函数模拟实现一:(strlen,strcpy,strcat,strcmp,strstr )
  • 最新最详细的Mastercam安装包下载安装教程(保姆级)
  • Go语言的垃圾回收(GC)机制的迭代和优化历史
  • 在HTML中添加图片
  • [数据结构]链表的实现在PHP中
  • 【跃迁之路】【699天】程序员高效学习方法论探索系列(实验阶段456-2019.1.19)...
  • Electron入门介绍
  • Go 语言编译器的 //go: 详解
  • ReactNativeweexDeviceOne对比
  • SpriteKit 技巧之添加背景图片
  • Vue2.0 实现互斥
  • yii2中session跨域名的问题
  • 关于for循环的简单归纳
  • 马上搞懂 GeoJSON
  • 如何解决微信端直接跳WAP端
  • 使用parted解决大于2T的磁盘分区
  • 一个项目push到多个远程Git仓库
  • 在weex里面使用chart图表
  • 正则表达式小结
  • ​LeetCode解法汇总2808. 使循环数组所有元素相等的最少秒数
  • !$boo在php中什么意思,php前戏
  • #我与Java虚拟机的故事#连载03:面试过的百度,滴滴,快手都问了这些问题
  • (Java实习生)每日10道面试题打卡——JavaWeb篇
  • (Matalb回归预测)PSO-BP粒子群算法优化BP神经网络的多维回归预测
  • (TipsTricks)用客户端模板精简JavaScript代码
  • (WSI分类)WSI分类文献小综述 2024
  • (二)springcloud实战之config配置中心
  • (附源码)基于SSM多源异构数据关联技术构建智能校园-计算机毕设 64366
  • (南京观海微电子)——COF介绍
  • (十七)Flink 容错机制
  • (十一)c52学习之旅-动态数码管
  • (一)、软硬件全开源智能手表,与手机互联,标配多表盘,功能丰富(ZSWatch-Zephyr)
  • .NET Windows:删除文件夹后立即判断,有可能依然存在
  • /dev下添加设备节点的方法步骤(通过device_create)
  • @KafkaListener注解详解(一)| 常用参数详解
  • [ 常用工具篇 ] POC-bomber 漏洞检测工具安装及使用详解
  • [ 隧道技术 ] cpolar 工具详解之将内网端口映射到公网
  • [AutoSar]BSW_Com07 CAN报文接收流程的函数调用
  • [BZOJ2281][SDOI2011]黑白棋(K-Nim博弈)
  • [C#]winform制作圆形进度条好用的圆环圆形进度条控件和使用方法
  • [C++]二叉搜索树
  • [EFI]ASUS EX-B365M-V5 Gold G5400 CPU电脑 Hackintosh 黑苹果引导文件
  • [Hive] CTE 通用表达式 WITH关键字
  • [HTML]Web前端开发技术29(HTML5、CSS3、JavaScript )JavaScript基础——喵喵画网页