当前位置: 首页 > news >正文

ChatGLM-6B 主要代码分析 RotaryEmbedding

ChatGLM-6B 主要代码分析 RotaryEmbedding

flyfish
在这里插入图片描述

图片链接地址

传统的 Transformer 位置编码(Positional Encoding)被称为绝对位置编码 ,而 Rotary Embedding 被称为相对位置编码 ,主要是因为它们编码位置信息的方式不同,进而影响模型对序列中元素之间位置关系的理解。

1. 传统 Transformer 位置编码:绝对位置编码

在传统的 Transformer 模型中,位置编码使用正弦和余弦函数将每个位置 t t t 映射到一个固定的向量: P E ( t , 2 i ) = sin ⁡ ( t 1000 0 2 i / d ) PE(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i)=sin(100002i/dt)

P E ( t , 2 i + 1 ) = cos ⁡ ( t 1000 0 2 i / d ) PE(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i+1)=cos(100002i/dt)
其中, t t t 是序列中的位置索引, i i i 是维度索引, d d d 是嵌入维度。

特点:
  • 固定位置编码 :每个位置 t t t 的编码是固定的,无论它出现在序列的哪个部分,其编码都是由位置 t t t 唯一确定的。

  • 不变性 :这种编码方式不会随着序列的变化而变化,意味着同一位置的编码在每次出现时都是相同的。

绝对性:
  • 绝对位置感知 :由于位置编码与序列中的具体位置 t t t 紧密关联,模型在训练时会将这些编码与特定的序列模式联系起来。这种方式能够让模型感知到序列中每个元素的绝对位置,但对元素之间的相对位置(如相对距离)缺乏直接的建模能力。

  • 难以处理相对位置信息 :在绝对位置编码下,如果需要感知两个元素之间的相对距离或关系,模型必须通过训练学习到这些关系,而不是通过位置编码直接得到。

2. Rotary Embedding:相对位置编码

Rotary Embedding 的核心思想是通过旋转操作,将位置信息嵌入到序列的每个元素中,从而使模型能够自然地感知到序列中元素之间的相对位置关系。

工作原理:
  1. 旋转矩阵 :Rotary Embedding 将位置信息与特征向量通过旋转矩阵结合。假设 x 1 x_1 x1 x 2 x_2 x2 是在位置 t t t t + 1 t+1 t+1 的特征向量,那么旋转操作后的位置编码变换为: R ( θ ) ⋅ x = [ cos ⁡ ( θ ) − sin ⁡ ( θ ) sin ⁡ ( θ ) cos ⁡ ( θ ) ] ⋅ [ x 1 x 2 ] R(\theta) \cdot x = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \cdot \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} R(θ)x=[cos(θ)sin(θ)sin(θ)cos(θ)][x1x2]
    其中 θ \theta θ 是根据位置计算得到的旋转角度。

  2. 相对位置感知 :当两个位置 t t t t + 1 t+1 t+1 的特征向量进行旋转变换时,模型可以通过旋转角度的差异自然感知到这两个位置之间的相对关系,而无需依赖绝对位置编码。

相对性:
  • 相对位置感知 :Rotary Embedding 通过旋转矩阵直接捕捉相邻元素之间的相对位置信息。例如,元素 x 1 x_1 x1 x 2 x_2 x2 在相邻位置 t t t t + 1 t+1 t+1 之间的相对关系可以通过旋转角度的差异直接表达。

  • 位置编码灵活性 :由于旋转矩阵使得位置编码可以灵活变化,因此模型能够更自然地处理不同长度的序列和不同的相对位置关系。

3. 绝对 vs. 相对位置编码

  • 绝对位置编码 (传统 Transformer):编码固定,适合处理具体位置相关的任务,但难以直接处理相对位置关系。

  • 相对位置编码 (Rotary Embedding):编码与序列中的相对位置变化相关,更加灵活,适合处理长序列和需要相对位置信息的任务。

Rotary Embedding 与传统位置编码的比较

特点传统位置编码 (Positional Encoding)Rotary Embedding
编码方式正弦和余弦函数的绝对位置编码旋转矩阵的相对位置编码
位置关系只能表示绝对位置更好地表示相对位置
对长序列的处理长序列时可能失效能够有效处理长序列
模型适应性需要在训练期间观察到所有可能位置更具扩展性,适应超长序列
应用场景适用于大多数任务尤其适用于需要处理长序列和复杂依赖关系的任务
import torch
class RotaryEmbedding(torch.nn.Module):def __init__(self, dim, base=10000, precision=torch.half, learnable=False):super().__init__()inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))inv_freq = inv_freq.half()self.learnable = learnableif learnable:self.inv_freq = torch.nn.Parameter(inv_freq)self.max_seq_len_cached = Noneelse:self.register_buffer('inv_freq', inv_freq)self.max_seq_len_cached = Noneself.cos_cached = Noneself.sin_cached = Noneself.precision = precisiondef _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,error_msgs):passdef forward(self, x, seq_dim=1, seq_len=None):if seq_len is None:seq_len = x.shape[seq_dim]if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):self.max_seq_len_cached = None if self.learnable else seq_lent = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum('i,j->ij', t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)if self.precision == torch.bfloat16:emb = emb.float()# [sx, 1 (b * np), hn]cos_cached = emb.cos()[:, None, :]sin_cached = emb.sin()[:, None, :]if self.precision == torch.bfloat16:cos_cached = cos_cached.bfloat16()sin_cached = sin_cached.bfloat16()if self.learnable:return cos_cached, sin_cachedself.cos_cached, self.sin_cached = cos_cached, sin_cachedreturn self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]def _apply(self, fn):if self.cos_cached is not None:self.cos_cached = fn(self.cos_cached)if self.sin_cached is not None:self.sin_cached = fn(self.sin_cached)return super()._apply(fn)# 初始化 RotaryEmbedding 模块
dim = 64  # 嵌入维度
rotary_emb = RotaryEmbedding(dim=dim)# 模拟输入张量
batch_size = 2
seq_len = 10
embedding_dim = dim
x = torch.randn(batch_size, seq_len, embedding_dim)# 调用 forward 方法
cos, sin = rotary_emb(x)# 输出 cos 和 sin 的形状
print("Cosine Embedding Shape:", cos.shape)
print("Sine Embedding Shape:", sin.shape)

输出

Cosine Embedding Shape: torch.Size([10, 1, 64])
Sine Embedding Shape: torch.Size([10, 1, 64])

Rotary Embedding 的设计思想是将位置编码嵌入到一个旋转的向量空间中,从而为序列建模提供更强的相对位置感知能力。

1. 三角函数基础

三角函数 cossin 描述了一个角度在单位圆上的投影,定义如下: cos ⁡ ( θ ) = 邻边 斜边 , sin ⁡ ( θ ) = 对边 斜边 \cos(\theta) = \frac{\text{邻边}}{\text{斜边}}, \quad \sin(\theta) = \frac{\text{对边}}{\text{斜边}} cos(θ)=斜边邻边,sin(θ)=斜边对边
这些函数具有周期性,对于任何角度 θ \theta θ,都有以下性质: cos ⁡ ( θ + 2 π ) = cos ⁡ ( θ ) , sin ⁡ ( θ + 2 π ) = sin ⁡ ( θ ) \cos(\theta + 2\pi) = \cos(\theta), \quad \sin(\theta + 2\pi) = \sin(\theta) cos(θ+2π)=cos(θ),sin(θ+2π)=sin(θ)

2. 位置编码(Positional Encoding)

在传统的 Transformer 模型中,位置编码通过 sincos 函数来表示输入序列中的位置信息。对于一个给定的位置 t t t,对应的编码可以表示为: P E ( t , 2 i ) = sin ⁡ ( t 1000 0 2 i / d ) PE(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i)=sin(100002i/dt)

P E ( t , 2 i + 1 ) = cos ⁡ ( t 1000 0 2 i / d ) PE(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) PE(t,2i+1)=cos(100002i/dt)
其中, t t t 是序列中的位置, i i i 是维度索引, d d d 是嵌入维度。这个编码方式保证了不同维度具有不同的频率,以便模型能够感知到位置的不同。

3. 旋转嵌入(Rotary Embedding)

Rotary Embedding 是一种改进的相对位置编码方法,其核心思想是将位置信息通过旋转矩阵嵌入到序列中的每个特征向量中。它通过以下步骤实现:

1). 逆频率生成
首先,生成一个逆频率向量 inv_freq inv_freq j = 1 base 2 j d \text{inv\_freq}_j = \frac{1}{\text{base}^{\frac{2j}{d}}} inv_freqj=based2j1
其中 base 通常取 10000,j 是维度索引,d 是嵌入维度。

2). 频率矩阵生成
接下来,计算频率矩阵 freqs,将逆频率与时间步长(即序列位置)相乘: freqs i , j = t i × inv_freq j \text{freqs}_{i,j} = t_i \times \text{inv\_freq}_j freqsi,j=ti×inv_freqj
其中 t i t_i ti 是序列位置。

3). 三角函数编码
频率矩阵的每个元素通过 cossin 进行编码,并合并为一个编码矩阵: emb = [ cos ⁡ ( freqs ) , sin ⁡ ( freqs ) ] \text{emb} = [\cos(\text{freqs}), \sin(\text{freqs})] emb=[cos(freqs),sin(freqs)]

4). 旋转变换
在旋转嵌入中,编码后的 cossin 矢量与输入向量进行旋转变换。给定一个输入向量 x x x 及其旋转矩阵 R ( θ ) R(\theta) R(θ) R ( θ ) ⋅ x = [ cos ⁡ ( θ ) − sin ⁡ ( θ ) sin ⁡ ( θ ) cos ⁡ ( θ ) ] ⋅ [ x 1 x 2 ] R(\theta) \cdot x = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix} \cdot \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} R(θ)x=[cos(θ)sin(θ)sin(θ)cos(θ)][x1x2]

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • vulnhub靶机 DC-9(渗透测试详解)
  • 顺丰科技25届秋季校园招聘常见问题答疑及校招网申测评笔试题型分析SHL题库Verify测评
  • IO器件性能评估
  • 刷刷前端手写题
  • 理解JavaScript的基本概念和语法:让网页动起来
  • 【笔记】Android 多用户模式和用户类型
  • Codeforces Round 965 (Div. 2)
  • 如何对 GitLab 中文版进行升级?
  • 鸿蒙内核源码分析(进程管理篇) | 谁在管理内核资源?
  • cpu管理
  • Oracle(63)什么是临时表(Temporary Table)?
  • Dubbo,Zookeeper,NSF,Druid,CouchDB未授权访问漏洞(附带修复方法)
  • GORM 插入和批量插入操作介绍
  • EmguCV学习笔记 VB.Net 2.S 特别示例
  • 系统运维工程师学习路线
  • 【5+】跨webview多页面 触发事件(二)
  • Asm.js的简单介绍
  • CSS居中完全指南——构建CSS居中决策树
  • DataBase in Android
  • Java教程_软件开发基础
  • JS 面试题总结
  • js操作时间(持续更新)
  • Js实现点击查看全文(类似今日头条、知乎日报效果)
  • linux学习笔记
  • Promise初体验
  • seaborn 安装成功 + ImportError: DLL load failed: 找不到指定的模块 问题解决
  • vue中实现单选
  • 技术:超级实用的电脑小技巧
  • 检测对象或数组
  • 看域名解析域名安全对SEO的影响
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 通过来模仿稀土掘金个人页面的布局来学习使用CoordinatorLayout
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • 看到一个关于网页设计的文章分享过来!大家看看!
  • 教程:使用iPhone相机和openCV来完成3D重建(第一部分) ...
  • ​【数据结构与算法】冒泡排序:简单易懂的排序算法解析
  • ​马来语翻译中文去哪比较好?
  • #《AI中文版》V3 第 1 章 概述
  • #QT(TCP网络编程-服务端)
  • $GOPATH/go.mod exists but should not goland
  • (C语言)球球大作战
  • (亲测成功)在centos7.5上安装kvm,通过VNC远程连接并创建多台ubuntu虚拟机(ubuntu server版本)...
  • (亲测有效)推荐2024最新的免费漫画软件app,无广告,聚合全网资源!
  • (四十一)大数据实战——spark的yarn模式生产环境部署
  • (限时免费)震惊!流落人间的haproxy宝典被找到了!一切玄妙尽在此处!
  • (一)插入排序
  • (轉貼) 寄發紅帖基本原則(教育部禮儀司頒布) (雜項)
  • .NET 依赖注入和配置系统
  • .net6+aspose.words导出word并转pdf
  • .NetCore部署微服务(二)
  • .NET学习全景图
  • .py文件应该怎样打开?
  • [2009][note]构成理想导体超材料的有源THz欺骗表面等离子激元开关——
  • [AAuto]给百宝箱增加娱乐功能
  • [BZOJ] 2006: [NOI2010]超级钢琴