Transformer模型 PostionEmbedding的实现
在Transformer模型中,Position Embedding(位置嵌入)是用来给模型提供关于序列中每个元素位置信息的机制。因为Transformer模型本身不依赖于循环结构(如RNN中的隐藏状态)或卷积结构(如CNN中的局部视野)来捕获序列中的顺序信息,所以需要额外的位置信息来帮助模型理解输入序列中元素的顺序。
位置嵌入的几种实现方式
1. 正弦和余弦函数(Sinusoidal Positional Encoding)
这是Transformer原始论文中采用的方法。它使用不同频率的正弦和余弦函数来计算位置嵌入。这种方法的优点是它可以允许模型学习到相对位置信息,因为对于任何固定的偏移量k,PE(pos+k)
都可以表示为PE(pos)
的线性函数。
具体公式如下:
PE(pos,2i)=sin(10000dmodel2ipos)
PE(pos,2i+1)=cos(10000dmodel2ipos)
其中,pos
是位置索引,i
是维度索引,d_model
是嵌入的维度。通过这种方式,每个维度对应一个正弦或余弦波,波长从2π
到10000 * 2π
变化。
2. 可学习的位置嵌入(Learnable Positional Embeddings)
与正弦和余弦函数相比,另一种方法是直接初始化一个位置嵌入矩阵,并在训练过程中更新这个矩阵。这种方法通常被称为可学习的位置嵌入。在这种方法中,每个位置都会关联一个与词嵌入维度相同的向量,这些向量通过反向传播进行更新。
3. 相对位置嵌入
除了上述两种方法外,还有一些工作提出了相对位置嵌入的概念,它考虑的是序列中元素之间的相对位置关系,而不是绝对位置。这种方法通常需要对Transformer架构进行一些修改,例如改变自注意力机制中的注意力分数计算方式。
实现示例(正弦和余弦函数)
这里给出一个使用PyTorch实现正弦和余弦位置嵌入的示例:
python复制代码
import torch | |
import numpy as np | |
def positional_encoding(position, d_model): | |
""" | |
生成位置嵌入。 | |
参数: | |
- position: 形状为 (batch_size, sequence_length) 的位置索引Tensor。 | |
- d_model: 嵌入的维度。 | |
返回: | |
- pos_encoding: 形状为 (batch_size, sequence_length, d_model) 的位置嵌入Tensor。 | |
""" | |
# 创建一个形状为 (1, sequence_length, d_model) 的Tensor | |
pos_encoding = torch.zeros((1, position.size(1), d_model)) | |
# 分成偶数索引和奇数索引的两组 | |
position = position.unsqueeze(1).unsqueeze(2) # 形状: (batch_size, 1, 1, sequence_length) | |
# 偶数索引 | |
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) | |
pos_encoding[:, :, 0::2] = torch.sin(position * div_term) | |
# 奇数索引 | |
div_term = torch.exp(torch.arange(1, d_model, 2).float() * (-np.log(10000.0) / d_model)) | |
pos_encoding[:, :, 1::2] = torch.cos(position * div_term) | |
# 调整形状为 (batch_size, sequence_length, d_model) | |
pos_encoding = pos_encoding.squeeze(0) | |
return pos_encoding | |
# 示例 | |
batch_size, sequence_length = 1, 100 | |
d_model = 512 | |
position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(0) # 形状: (1, sequence_length) | |
pos_encoding = positional_encoding(position, d_model) | |
print(pos_encoding.shape) # torch.Size([1, 100, 512]) |
这段代码演示了如何为序列中的每个位置生成位置嵌入。在实际应用中,这些位置嵌入会与词嵌入相加,然后一起作为Transformer模型的输入。