当前位置: 首页 > news >正文

transformer_正余弦位置编码代码笔记

transformer_正余弦位置编码代码笔记

transformer输入的序列中,不同位置的相同词汇可能会表达不同的含义,通过考虑位置信息的不同来区分序列中不同位置的相同词汇。

位置编码有多种方式,此处仅记录正余弦位置编码

正余弦位置编码公式如下:
在这里插入图片描述
代码如下:

import numpy as np
import torchdef positional_encoding(seq_len, d_model):# 创建一个形状为(seq_len, 1)的数组,其中的值为[0, 1, 2, ... seq_len-1]position = np.arange(seq_len)[:, np.newaxis]  # 使用np.newaxis增加列上的维度,position矩阵为seq_len×1# 计算除数,这里的除数将用于计算正弦和余弦的频率,div_term矩阵为1×d_modeldiv_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))# 初始化位置编码矩阵为零,后续计算所有位置的位置编码并更新相对位置的初始化位置编码矩阵pe = np.zeros((seq_len, d_model))# 以下是针对偶数列使用正弦函数,奇数列使用余弦函数,最终输出的结果矩阵为seq_len×d_model# 对矩阵的偶数列机型正弦函数编码pe[:, 0::2] = np.sin(position * div_term)# 对矩阵的奇数列机型余弦函数编码pe[:, 1::2] = np.cos(position * div_term)# 返回位置编码矩阵,转换为PyTorch张量return torch.tensor(pe, dtype=torch.float32)if __name__ == '__main__':# 使用示例seq_len = 50  # 定义序列长度d_model = 512  # 定义模型的embedding维度pe = positional_encoding(seq_len, d_model)  # 获得位置编码print(pe)

实际使用时代码如下:

# forward the GPT model itself
# token的embedding
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
# 位置的embedding
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
# 将token的embedding与位置得到embedding相加
x = self.transformer.drop(tok_emb + pos_emb)

相关文章:

  • 服务器为什么老是被攻击?被攻击了怎么办?
  • 十一、常用API——练习
  • 《HTML 简易速速上手小册》第8章:HTML 表单高级技术(2024 最新版)
  • 【云上建站】快速在云上构建个人网站3——网站选型和搭建
  • 用ASM HEMT模型提取GaN器件的参数
  • VUE!!!必看
  • ElementUI Form:Input 输入框
  • 消息中间件(MQ)对比:RabbitMQ、Kafka、ActiveMQ 和 RocketMQ
  • 本地socket通信
  • linux ping 某台服务的端口
  • C#中的WebApi响应Accept头,自动返回xml或者json
  • SpringBoot中异步方法的使用
  • 【Vue】vue项目中使用tinymce富文本组件(@tinymce/tinymce-vue)
  • KAFKA鉴权设计以及相关探讨
  • 2024第16届中国西部教体融合博览会在成渝双城举办
  • 「面试题」如何实现一个圣杯布局?
  • 【Amaple教程】5. 插件
  • CoolViewPager:即刻刷新,自定义边缘效果颜色,双向自动循环,内置垂直切换效果,想要的都在这里...
  • mysql 数据库四种事务隔离级别
  • Sublime text 3 3103 注册码
  • win10下安装mysql5.7
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • 函数式编程与面向对象编程[4]:Scala的类型关联Type Alias
  • 理解在java “”i=i++;”所发生的事情
  • 树莓派 - 使用须知
  • Hibernate主键生成策略及选择
  • #android不同版本废弃api,新api。
  • #LLM入门|Prompt#1.7_文本拓展_Expanding
  • #调用传感器数据_Flink使用函数之监控传感器温度上升提醒
  • #绘制圆心_R语言——绘制一个诚意满满的圆 祝你2021圆圆满满
  • #我与虚拟机的故事#连载20:周志明虚拟机第 3 版:到底值不值得买?
  • %3cscript放入php,跟bWAPP学WEB安全(PHP代码)--XSS跨站脚本攻击
  • (1)(1.8) MSP(MultiWii 串行协议)(4.1 版)
  • (2021|NIPS,扩散,无条件分数估计,条件分数估计)无分类器引导扩散
  • (3)(3.2) MAVLink2数据包签名(安全)
  • (C语言)编写程序将一个4×4的数组进行顺时针旋转90度后输出。
  • (保姆级教程)Mysql中索引、触发器、存储过程、存储函数的概念、作用,以及如何使用索引、存储过程,代码操作演示
  • (层次遍历)104. 二叉树的最大深度
  • (附源码)SSM环卫人员管理平台 计算机毕设36412
  • (附源码)计算机毕业设计高校学生选课系统
  • (九)One-Wire总线-DS18B20
  • (十)c52学习之旅-定时器实验
  • (转贴)用VML开发工作流设计器 UCML.NET工作流管理系统
  • .Net 8.0 新的变化
  • .NET CF命令行调试器MDbg入门(二) 设备模拟器
  • .Net FrameWork总结
  • .net2005怎么读string形的xml,不是xml文件。
  • .Net6 Api Swagger配置
  • .NET国产化改造探索(三)、银河麒麟安装.NET 8环境
  • .net和jar包windows服务部署
  • @取消转义
  • [2016.7.Test1] T1 三进制异或
  • [2024] 十大免费电脑数据恢复软件——轻松恢复电脑上已删除文件
  • [AIGC] 开源流程引擎哪个好,如何选型?
  • [CISCN2019 华东北赛区]Web2