当前位置: 首页 > news >正文

Transformer模型 PostionEmbedding的实现

在Transformer模型中,Position Embedding(位置嵌入)是用来给模型提供关于序列中每个元素位置信息的机制。因为Transformer模型本身不依赖于循环结构(如RNN中的隐藏状态)或卷积结构(如CNN中的局部视野)来捕获序列中的顺序信息,所以需要额外的位置信息来帮助模型理解输入序列中元素的顺序。

位置嵌入的几种实现方式

1. 正弦和余弦函数(Sinusoidal Positional Encoding)

这是Transformer原始论文中采用的方法。它使用不同频率的正弦和余弦函数来计算位置嵌入。这种方法的优点是它可以允许模型学习到相对位置信息,因为对于任何固定的偏移量k,PE(pos+k)都可以表示为PE(pos)的线性函数。

具体公式如下:

PE(pos,2i)​=sin(10000dmodel​2i​pos​)

PE(pos,2i+1)​=cos(10000dmodel​2i​pos​)

其中,pos 是位置索引,i 是维度索引,d_model 是嵌入的维度。通过这种方式,每个维度对应一个正弦或余弦波,波长从10000 * 2π变化。

2. 可学习的位置嵌入(Learnable Positional Embeddings)

与正弦和余弦函数相比,另一种方法是直接初始化一个位置嵌入矩阵,并在训练过程中更新这个矩阵。这种方法通常被称为可学习的位置嵌入。在这种方法中,每个位置都会关联一个与词嵌入维度相同的向量,这些向量通过反向传播进行更新。

3. 相对位置嵌入

除了上述两种方法外,还有一些工作提出了相对位置嵌入的概念,它考虑的是序列中元素之间的相对位置关系,而不是绝对位置。这种方法通常需要对Transformer架构进行一些修改,例如改变自注意力机制中的注意力分数计算方式。

实现示例(正弦和余弦函数)

这里给出一个使用PyTorch实现正弦和余弦位置嵌入的示例:

 

python复制代码

import torch
import numpy as np
def positional_encoding(position, d_model):
"""
生成位置嵌入。
参数:
- position: 形状为 (batch_size, sequence_length) 的位置索引Tensor。
- d_model: 嵌入的维度。
返回:
- pos_encoding: 形状为 (batch_size, sequence_length, d_model) 的位置嵌入Tensor。
"""
# 创建一个形状为 (1, sequence_length, d_model) 的Tensor
pos_encoding = torch.zeros((1, position.size(1), d_model))
# 分成偶数索引和奇数索引的两组
position = position.unsqueeze(1).unsqueeze(2) # 形状: (batch_size, 1, 1, sequence_length)
# 偶数索引
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pos_encoding[:, :, 0::2] = torch.sin(position * div_term)
# 奇数索引
div_term = torch.exp(torch.arange(1, d_model, 2).float() * (-np.log(10000.0) / d_model))
pos_encoding[:, :, 1::2] = torch.cos(position * div_term)
# 调整形状为 (batch_size, sequence_length, d_model)
pos_encoding = pos_encoding.squeeze(0)
return pos_encoding
# 示例
batch_size, sequence_length = 1, 100
d_model = 512
position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(0) # 形状: (1, sequence_length)
pos_encoding = positional_encoding(position, d_model)
print(pos_encoding.shape) # torch.Size([1, 100, 512])

这段代码演示了如何为序列中的每个位置生成位置嵌入。在实际应用中,这些位置嵌入会与词嵌入相加,然后一起作为Transformer模型的输入。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • @RequestBody与@RequestParam:Spring MVC中的参数接收差异解析
  • WPF ToolkitMVVM IOC IServiceConllection
  • ssrf+redis未授权访问漏洞复现
  • 【SpringCloud应用框架】GateWay网关
  • 【AI绘画】Midjourney前置/imagine与单图指令详解
  • 【递归深搜之记忆化搜索算法】
  • 缓存解决方案。Redis 和 Amazon ElastiCache 比较
  • 力扣top300:3. 无重复字符的最长子串
  • VMware安装中标麒麟操作系统V7.0
  • 无人机之云台的作用
  • 数字化转型升级探索(一)
  • Spring Cloud全解析:网关之GateWay断言
  • 基于FreeRTOS的STM32多功能手表
  • STM32-PWM驱动舵机——HAL库
  • Kafka 到数据仓库:使用 bend-ingest-kafka 将消息加载到 Databend
  • ESLint简单操作
  • Javascripit类型转换比较那点事儿,双等号(==)
  • jquery cookie
  • js正则,这点儿就够用了
  • Netty源码解析1-Buffer
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • PAT A1017 优先队列
  • Shell编程
  • 成为一名优秀的Developer的书单
  • 对超线程几个不同角度的解释
  • 多线程 start 和 run 方法到底有什么区别?
  • 基于组件的设计工作流与界面抽象
  • 前端面试题总结
  • 浅析微信支付:申请退款、退款回调接口、查询退款
  • 如何将自己的网站分享到QQ空间,微信,微博等等
  • 使用Maven插件构建SpringBoot项目,生成Docker镜像push到DockerHub上
  • 适配mpvue平台的的微信小程序日历组件mpvue-calendar
  • 我与Jetbrains的这些年
  • 学习使用ExpressJS 4.0中的新Router
  • 在weex里面使用chart图表
  • ​ 全球云科技基础设施:亚马逊云科技的海外服务器网络如何演进
  • "无招胜有招"nbsp;史上最全的互…
  • (13):Silverlight 2 数据与通信之WebRequest
  • (175)FPGA门控时钟技术
  • (42)STM32——LCD显示屏实验笔记
  • (二)基于wpr_simulation 的Ros机器人运动控制,gazebo仿真
  • (附源码)c#+winform实现远程开机(广域网可用)
  • (剑指Offer)面试题34:丑数
  • (实测可用)(3)Git的使用——RT Thread Stdio添加的软件包,github与gitee冲突造成无法上传文件到gitee
  • (一)Java算法:二分查找
  • (一)python发送HTTP 请求的两种方式(get和post )
  • (原創) 博客園正式支援VHDL語法著色功能 (SOC) (VHDL)
  • (转) Android中ViewStub组件使用
  • (转)总结使用Unity 3D优化游戏运行性能的经验
  • .a文件和.so文件
  • .net 获取url的方法
  • .NET3.5下用Lambda简化跨线程访问窗体控件,避免繁复的delegate,Invoke(转)
  • .netcore 6.0/7.0项目迁移至.netcore 8.0 注意事项
  • .NetCore+vue3上传图片 Multipart body length limit 16384 exceeded.
  • .NET设计模式(2):单件模式(Singleton Pattern)