llama3 结构详解
文章目录
- 1. Llama3整体结构
- 2. Embeddings 模块
- 3. Transformer Block 模块
- 3.1 RMSNorm
- 3.2 Attention模块
- 2.3 FFN
- 3. RMSNorm 模块
- 4. Linear
1. Llama3整体结构
llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。
2. Embeddings 模块
llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。
3. Transformer Block 模块
Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。其结构如下图。我这里按模块结构拆成三部分来讲解,分别是RSMNorm, Attention,FFN
3.1 RMSNorm
RSMNorm 是在 layer normalization 基础上优化而来,所以先讲下layer normalization。这里直接引用下我在《Transformer(二)–论文理解:transformer 结构详解》关于layer normalization的解释。
不论是layer normalization还是batch normalization,其实做的都是一件事情,都是根据 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=a∗std+epsx−x+b对 x x x的分布进行调整。不同的是 x ‾ \overline{x} x和 s t d std std的计算方式不同。如下图:
RMSNorm做了什么优化呢,其实他对上面的试子 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=a∗std+epsx−x+b进行了简化。RMSNorm的计算公式如下:
a ‾ i = a i R M S ( a ) g i , w h e r e R M S ( a ) = 1 n Σ i = 1 n a i 2 \overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}} ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2
从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(ai−a))。这种简化在计算效率是有一定提高,且原始论文也说了,没有在效果上有明显影响。
下面附上meta llama3中RMSNorm的源码,方便大家理解。
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight
3.2 Attention模块
llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
我们知道,在《Attention is all you need》一文中,作者为了提高计算效率,提出了MHA技术,思想是采用分而治之的策略,把K、Q、V 对应的切分为若干个短向量,然后使用Scaled Dot-Product Attention 计算出attention score后,再把结果拼接起来,从而避免了超大向量乘法的计算消耗,从而提高了计算效率。如下图所示。
然而,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
说了半天,其实在源码层次来就,就是在计算Scaled Dot-Product Attention之前对query进行个分组,组内共享一套Key和value。下面是meta llama3中的Attention类,方便大家理解。
class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()...def forward(self,x: torch.Tensor,start_pos: int,freqs_cis: torch.Tensor,mask: Optional[torch.Tensor],):bsz, seqlen, _ = x.shapexq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)self.cache_k = self.cache_k.to(xq)self.cache_v = self.cache_v.to(xq)self.cache_k[:bsz, start_pos : start_pos + seqlen] = xkself.cache_v[:bsz, start_pos : start_pos + seqlen] = xvkeys = self.cache_k[:bsz, : start_pos + seqlen]values = self.cache_v[:bsz, : start_pos + seqlen]# repeat k/v heads if n_kv_heads < n_headskeys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)# 以下是Scaled Dot-Product Attention的计算scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)if mask is not None:scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)scores = F.softmax(scores.float(), dim=-1).type_as(xq)output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)return self.wo(output)
2.3 FFN
由3个Linear组成的FeedForward网络,不再赘述。
3. RMSNorm 模块
同3.1 RMSNorm 模块
4. Linear
此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)