当前位置: 首页 > news >正文

【HuggingFace Transformers】BertSelfOutput 和 BertOutput源码解析

BertSelfOutput 和 BertOutput源码解析

  • 1. 介绍
    • 1.1 共同点
      • (1) 残差连接 (Residual Connection)
      • (2) 层归一化 (Layer Normalization)
      • (3) Dropout
      • (4) 线性变换 (Linear Transformation)
    • 1.2 不同点
      • (1) 处理的输入类型
      • (2) 线性变换的作用
      • (3) 输入的特征大小
  • 2. 源码解析
    • 2.1 BertSelfOutput 源码解析
    • 2.2 BertOutput 源码解析

1. 介绍

BertSelfOutputBertOutputBERT 模型中两个相关但不同的模块。它们在功能上有许多共同点,但也有一些关键的不同点。以下通过共同点和不同点来介绍它们。

1.1 共同点

BertSelfOutputBertOutput 都包含残差连接、层归一化、Dropout 和线性变换,并且这些操作的顺序相似。

(1) 残差连接 (Residual Connection)

两个模块都应用了残差连接,即将模块的输入直接与经过线性变换后的输出相加。这种结构可以帮助缓解深层神经网络中的梯度消失问题,使信息更直接地传递,保持梯度流动顺畅。

(2) 层归一化 (Layer Normalization)

在应用残差连接后,两个模块都使用层归一化 (LayerNorm) 来规范化输出。这有助于加速训练,稳定网络性能,并减少内部分布变化的问题。

(3) Dropout

两个模块都包含一个 Dropout 层,用于随机屏蔽一部分神经元的输出,增强模型的泛化能力,防止过拟合。

(4) 线性变换 (Linear Transformation)

两个模块都包含一个线性变换 (dense 层)。这个线性变换用于调整数据的维度,并为后续的残差连接和层归一化做准备。

1.2 不同点

BertSelfOutput 专注于处理自注意力机制的输出,而 BertOutput 则处理前馈神经网络的输出。它们的输入特征维度也有所不同,线性变换的作用在两个模块中也略有差异。

(1) 处理的输入类型

  • BertSelfOutput:处理自注意力机制 (BertSelfAttention) 的输出。它关注的是如何将注意力机制生成的特征向量与原始输入结合起来。
  • BertOutput:处理的是前馈神经网络的输出。它将经过注意力机制处理后的特征进一步加工,并整合到当前层的最终输出中。

(2) 线性变换的作用

  • BertSelfOutput:线性变换的作用是对自注意力机制的输出进行进一步的变换和投影,使其适应后续的处理流程。
  • BertOutput:线性变换的作用是对前馈神经网络的输出进行变换,使其与前一层的输出相结合,并准备传递到下一层。

(3) 输入的特征大小

  • BertSelfOutput:输入和输出的特征维度保持一致,都是 BERT 模型的隐藏层大小 (hidden_size)。
  • BertOutput:输入的特征维度是中间层大小 (intermediate_size),输出则是 BERT 模型的隐藏层大小 (hidden_size)。这意味着 BertOutput 的线性变换需要将中间层的维度转换回隐藏层的维度。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertSelfOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:27import torch
from torch import nnclass BertSelfOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 定义线性变换层,将自注意力输出映射到 hidden_size 维度self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states)  # 对自注意力机制的输出进行线性变换hidden_states = self.dropout(hidden_states)  # Dropout操作hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化return hidden_states

2.2 BertOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/8/22 15:41import torch
from torch import nnclass BertOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # 定义线性变换层,将前馈神经网络输出从 intermediate_size 映射到 hidden_sizeself.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states)  # 对前馈神经网络的输出进行线性变换hidden_states = self.dropout(hidden_states)  # Dropout操作hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化return hidden_states

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • ARM——驱动——内核编译
  • zdppy+vue3+onlyoffice文档管理系统实战 20240825上课笔记 zdppy_cache框架增加resize清理缓存的方法
  • javascript(js)入门指南
  • 【Android】Android Studio中利用git进行协同开发
  • 杰发科技AC7840——CAN通信简介(8)_通过波特率和时钟计算SEG_1/SEG_2/SJW/PRESC
  • 淘客系统源码的架构分析
  • 徐州服务器租用:高防服务器的用途有哪些?
  • 在 MyBatis 中进行一对多的连表子查询
  • thinkphp8 定时任务 addOption
  • leetcode 数组+哈希+双指针+子串+滑动窗口
  • 网络安全 DVWA通关指南 DVWA File Upload(文件上传)
  • 华为手机换ip地址怎么换?手机换ip地址有什么影响
  • 前端宝典十八:高频算法排序之冒泡、插入、选择、归并和快速
  • 利用网络爬虫获取数据的刑事责任分析
  • FPGA在医疗方面的应用
  • 【Leetcode】101. 对称二叉树
  • 30天自制操作系统-2
  • Akka系列(七):Actor持久化之Akka persistence
  • Brief introduction of how to 'Call, Apply and Bind'
  • C++类的相互关联
  • ES10 特性的完整指南
  • GitUp, 你不可错过的秀外慧中的git工具
  • NSTimer学习笔记
  • PHP 7 修改了什么呢 -- 2
  • vue--为什么data属性必须是一个函数
  • 半理解系列--Promise的进化史
  • 规范化安全开发 KOA 手脚架
  • 理清楚Vue的结构
  • 配置 PM2 实现代码自动发布
  • 前端存储 - localStorage
  • 做一名精致的JavaScripter 01:JavaScript简介
  • 《码出高效》学习笔记与书中错误记录
  • 整理一些计算机基础知识!
  • # windows 安装 mysql 显示 no packages found 解决方法
  • #QT(串口助手-界面)
  • (03)光刻——半导体电路的绘制
  • (7)摄像机和云台
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (附源码)spring boot火车票售卖系统 毕业设计 211004
  • (附源码)springboot家庭装修管理系统 毕业设计 613205
  • (论文阅读26/100)Weakly-supervised learning with convolutional neural networks
  • (每日持续更新)jdk api之FileReader基础、应用、实战
  • (入门自用)--C++--抽象类--多态原理--虚表--1020
  • (学习日记)2024.03.25:UCOSIII第二十二节:系统启动流程详解
  • (一)认识微服务
  • (原创)可支持最大高度的NestedScrollView
  • (原創) 如何刪除Windows Live Writer留在本機的文章? (Web) (Windows Live Writer)
  • (转)visual stdio 书签功能介绍
  • ****** 二十三 ******、软设笔记【数据库】-数据操作-常用关系操作、关系运算
  • *算法训练(leetcode)第三十九天 | 115. 不同的子序列、583. 两个字符串的删除操作、72. 编辑距离
  • ./configure、make、make install 命令
  • .Net FrameWork总结
  • .NET IoC 容器(三)Autofac
  • .Net Redis的秒杀Dome和异步执行
  • .Net8 Blazor 尝鲜