当前位置: 首页 > news >正文

llama3 结构详解

文章目录

  • 1. Llama3整体结构
  • 2. Embeddings 模块
  • 3. Transformer Block 模块
    • 3.1 RMSNorm
    • 3.2 Attention模块
    • 2.3 FFN
  • 3. RMSNorm 模块
  • 4. Linear

1. Llama3整体结构

  llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。
在这里插入图片描述

2. Embeddings 模块

  llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。

3. Transformer Block 模块

  Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。其结构如下图。我这里按模块结构拆成三部分来讲解,分别是RSMNorm, Attention,FFN

在这里插入图片描述

3.1 RMSNorm

  RSMNorm 是在 layer normalization 基础上优化而来,所以先讲下layer normalization。这里直接引用下我在《Transformer(二)–论文理解:transformer 结构详解》关于layer normalization的解释。

  不论是layer normalization还是batch normalization,其实做的都是一件事情,都是根据 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b x x x的分布进行调整。不同的是 x ‾ \overline{x} x s t d std std的计算方式不同。如下图:
在这里插入图片描述

  RMSNorm做了什么优化呢,其实他对上面的试子 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b进行了简化。RMSNorm的计算公式如下:
a ‾ i = a i R M S ( a ) g i , w h e r e R M S ( a ) = 1 n Σ i = 1 n a i 2 \overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}} ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2

  从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(aia) )。这种简化在计算效率是有一定提高,且原始论文也说了,没有在效果上有明显影响。

下面附上meta llama3中RMSNorm的源码,方便大家理解。

class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight

3.2 Attention模块

  llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
  我们知道,在《Attention is all you need》一文中,作者为了提高计算效率,提出了MHA技术,思想是采用分而治之的策略,把K、Q、V 对应的切分为若干个短向量,然后使用Scaled Dot-Product Attention 计算出attention score后,再把结果拼接起来,从而避免了超大向量乘法的计算消耗,从而提高了计算效率。如下图所示。
在这里插入图片描述

  然而,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
在这里插入图片描述
  说了半天,其实在源码层次来就,就是在计算Scaled Dot-Product Attention之前对query进行个分组,组内共享一套Key和value。下面是meta llama3中的Attention类,方便大家理解。

class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_headsmodel_parallel_size = fs_init.get_model_parallel_world_size()...def forward(self,x: torch.Tensor,start_pos: int,freqs_cis: torch.Tensor,mask: Optional[torch.Tensor],):bsz, seqlen, _ = x.shapexq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)self.cache_k = self.cache_k.to(xq)self.cache_v = self.cache_v.to(xq)self.cache_k[:bsz, start_pos : start_pos + seqlen] = xkself.cache_v[:bsz, start_pos : start_pos + seqlen] = xvkeys = self.cache_k[:bsz, : start_pos + seqlen]values = self.cache_v[:bsz, : start_pos + seqlen]# repeat k/v heads if n_kv_heads < n_headskeys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)# 以下是Scaled Dot-Product Attention的计算scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)if mask is not None:scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)scores = F.softmax(scores.float(), dim=-1).type_as(xq)output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)return self.wo(output)

2.3 FFN

  由3个Linear组成的FeedForward网络,不再赘述。

3. RMSNorm 模块

同3.1 RMSNorm 模块

4. Linear

  此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。

 self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Upload-Lab第12关:如何巧妙利用%00截断法绕过上传验证
  • linux 改文件夹所有者
  • Git工具练习网站
  • 【k8s从节点报错】error: You must be logged in to the server (Unauthorized)
  • Oracle RAC vs Clusterware vs ASM
  • 【Linux系列】telnet使用入门
  • 基于Mybatis 数据过滤组件(二) - 使用文档
  • web技术1——http详解(重要)
  • 兼容并蓄,高效集成:EasyCVR视频综合接入能力助力多元化项目需求
  • Fragment学习笔记
  • 数组前缀和算法技巧
  • html+css网页设计 淘宝首页
  • 数据处理二维数组转单数组
  • 免费商用字体下载指南!(哪里可以免费下载字体,哪里可以免费下载可商用字体)
  • C++ 模版进阶【非类型模板参数、模板特化等】
  • Angular 响应式表单 基础例子
  • CODING 缺陷管理功能正式开始公测
  • Consul Config 使用Git做版本控制的实现
  • Date型的使用
  • Docker: 容器互访的三种方式
  • js算法-归并排序(merge_sort)
  • MySQL用户中的%到底包不包括localhost?
  • MySQL主从复制读写分离及奇怪的问题
  • Python 反序列化安全问题(二)
  • Shell编程
  • 大整数乘法-表格法
  • 关于使用markdown的方法(引自CSDN教程)
  • 力扣(LeetCode)357
  • 前端存储 - localStorage
  • 前端每日实战 2018 年 7 月份项目汇总(共 29 个项目)
  • 微服务核心架构梳理
  • 为物联网而生:高性能时间序列数据库HiTSDB商业化首发!
  • 我的面试准备过程--容器(更新中)
  • 一加3T解锁OEM、刷入TWRP、第三方ROM以及ROOT
  • #Lua:Lua调用C++生成的DLL库
  • #绘制圆心_R语言——绘制一个诚意满满的圆 祝你2021圆圆满满
  • (1)(1.11) SiK Radio v2(一)
  • (9)YOLO-Pose:使用对象关键点相似性损失增强多人姿态估计的增强版YOLO
  • (C)一些题4
  • (Mirage系列之二)VMware Horizon Mirage的经典用户用例及真实案例分析
  • (补充):java各种进制、原码、反码、补码和文本、图像、音频在计算机中的存储方式
  • (三)uboot源码分析
  • (学习日记)2024.01.09
  • (转)Google的Objective-C编码规范
  • (转)拼包函数及网络封包的异常处理(含代码)
  • (转)自己动手搭建Nginx+memcache+xdebug+php运行环境绿色版 For windows版
  • (转载)从 Java 代码到 Java 堆
  • ***检测工具之RKHunter AIDE
  • .net core 的缓存方案
  • .NET Core 实现 Redis 批量查询指定格式的Key
  • .NET 使用 ILRepack 合并多个程序集(替代 ILMerge),避免引入额外的依赖
  • .NET(C#) Internals: as a developer, .net framework in my eyes
  • .NET8 动态添加定时任务(CRON Expression, Whatever)
  • /bin/bash^M: bad interpreter: No such file or directory
  • @for /l %i in (1,1,10) do md %i 批处理自动建立目录