Llama3:全模型GQA与tiktoken分词的新突破
在本篇文章中,我们将介绍Llama3模型,并且对比它与Llama2在模型层面上的主要区别。Llama3 相较于Llama2的最显著变化是引入了全模型GQA(Grouped Query Attention)机制,并且在分词阶段使用了与GPT一致的 tiktoken
分词方式。
Llama3 和 Llama2 的模型层面区别
Llama3 相较于 Llama2 的主要区别在于其全模型使用了 GQA(Grouped Query Attention),这使得多头注意力机制中的键值对变得更加高效,减少了计算和内存开销。
模型参数定义
在实现 Llama3 时,我们使用了 Python 的 @dataclass
装饰器来定义模型的超参数。@dataclass
能够简化类的定义过程,自动生成构造函数 __init__()
,打印方法 __repr__()
,以及判断两个类是否相等的 __eq__()
。
代码示例:
@dataclass
class ModelArgs:dim: int = 4096 # 模型维度n_layers: int = 6 # 层数n_heads: int = 6 # 注意力头数n_group: Optional[int] = 3 # GQA组数vocab_size: int = 4096 # 词表大小hidden_dim: Optional[int] = None # 隐藏层维度multiple_of: int = 256 # MLP层隐层维度的计算因子norm_eps: float = 1e-5 # 正则化epsmax_seq_len: int = 2048 # 最大序列长度dropout: float = 0.0 # Dropout比率
RMS正则化
RMS正则化的原理已经在之前的 Qwen 文章中讲解过,Llama3 采用了同样的 RMSNorm 来实现层的标准化。
代码示例:
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight
ROPE 相对位置嵌入
ROPE(Rotary Positional Embedding)的实现与 Qwen 模型类似,负责对自注意力的查询(Query)和键(Key)进行位置编码。
代码示例:
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cosxq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)
Grouped Query Attention (GQA)
在 Llama3 中,Attention 模块使用了 GQA 机制,这意味着每组注意力头共享相同的键和值,这种方法减少了计算开销。
代码示例:
class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.group = args.n_groupself.heads = args.n_headsself.kv_heads = args.n_heads // args.n_groupself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.attn_dropout = nn.Dropout(args.dropout)self.resid_dropout = nn.Dropout(args.dropout)def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(-1, xq.size(1), self.heads, self.head_dim)xk = xk.view(-1, xk.size(1), self.kv_heads, self.head_dim)xv = xv.view(-1, xv.size(1), self.kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)xk = repeat_kv(xk, self.group)xv = repeat_kv(xv, self.group)scores = torch.matmul(xq, xk.transpose(-1, -2)) / math.sqrt(self.head_dim)scores = torch.softmax(scores, dim=-1)output = torch.matmul(scores, xv)output = output.transpose(1, 2).contiguous().view(-1, x.size(1), self.heads * self.head_dim)return self.wo(output)
FeedForward 模块
Llama3 的 MLP 模块通过线性变换、激活函数和 Dropout 组成,与 Qwen 模型一致。
代码示例:
class FeedForward(nn.Module):def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):super().__init__()self.w1 = nn.Linear(dim, hidden_dim, bias=False)self.w2 = nn.Linear(hidden_dim, dim, bias=False)self.w3 = nn.Linear(dim, hidden_dim, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
TransformerBlock: 将模块组合成完整层
Llama3 的 Transformer 层由 Attention、FeedForward、RMSNorm 等模块组成,通过多层堆叠构建模型。
代码示例:
class TransformerBlock(nn.Module):def __init__(self, layer_id: int, args: ModelArgs):super().__init__()self.attention = Attention(args)self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim, multiple_of=args.multiple_of, dropout=args.dropout)self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)def forward(self, x, freqs_cos, freqs_sin):h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)out = h + self.feed_forward.forward(self.ffn_norm(h))return out
Transformer模型:完整的 Llama3 实现
代码示例:
class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)self.layers = nn.ModuleList([TransformerBlock(i, params) for i in range(params.n_layers)])self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = nn.Linear(params.dim, params.vocab_size, bias=False)self.tok_embeddings.weight = self.output.weightfreqs_cos, freqs_sin = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len)self.register_buffer("freqs_cos", freqs_cos, persistent=False)self.register_buffer("freqs_sin", freqs_sin, persistent=False)def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:h = self.tok_embeddings(tokens)for layer in self.layers:h = layer(h, self.freqs_cos[:h.size(1)], self.freqs_sin[:h.size(1)])h = self.norm(h)return self.output(h)
结语
通过本篇文章,我们学习了如何从零开始预训练Llama3模型,并认识了它与Llama2在模型结构上的主要区别。Llama3的引入GQA机制大幅提升了模型的推理效率,同时结合
tiktoken
的分词方式,使其在处理文本任务时更具优势。后续,我们将进一步更新关于数据预处理和模型优化的相关教程。
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!