当前位置: 首页 > news >正文

Llama3:全模型GQA与tiktoken分词的新突破

        在本篇文章中,我们将介绍Llama3模型,并且对比它与Llama2在模型层面上的主要区别。Llama3 相较于Llama2的最显著变化是引入了全模型GQA(Grouped Query Attention)机制,并且在分词阶段使用了与GPT一致的 tiktoken 分词方式。

Llama3 和 Llama2 的模型层面区别

        Llama3 相较于 Llama2 的主要区别在于其全模型使用了 GQA(Grouped Query Attention),这使得多头注意力机制中的键值对变得更加高效,减少了计算和内存开销。

模型参数定义

        在实现 Llama3 时,我们使用了 Python 的 @dataclass 装饰器来定义模型的超参数。@dataclass 能够简化类的定义过程,自动生成构造函数 __init__(),打印方法 __repr__(),以及判断两个类是否相等的 __eq__()

代码示例:
@dataclass
class ModelArgs:dim: int = 4096  # 模型维度n_layers: int = 6  # 层数n_heads: int = 6  # 注意力头数n_group: Optional[int] = 3  # GQA组数vocab_size: int = 4096  # 词表大小hidden_dim: Optional[int] = None  # 隐藏层维度multiple_of: int = 256  # MLP层隐层维度的计算因子norm_eps: float = 1e-5  # 正则化epsmax_seq_len: int = 2048  # 最大序列长度dropout: float = 0.0  # Dropout比率

RMS正则化

        RMS正则化的原理已经在之前的 Qwen 文章中讲解过,Llama3 采用了同样的 RMSNorm 来实现层的标准化。

代码示例:
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight

ROPE 相对位置嵌入

        ROPE(Rotary Positional Embedding)的实现与 Qwen 模型类似,负责对自注意力的查询(Query)和键(Key)进行位置编码。

代码示例:
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cosxq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)

Grouped Query Attention (GQA)

        在 Llama3 中,Attention 模块使用了 GQA 机制,这意味着每组注意力头共享相同的键和值,这种方法减少了计算开销。

代码示例:
class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.group = args.n_groupself.heads = args.n_headsself.kv_heads = args.n_heads // args.n_groupself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.attn_dropout = nn.Dropout(args.dropout)self.resid_dropout = nn.Dropout(args.dropout)def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(-1, xq.size(1), self.heads, self.head_dim)xk = xk.view(-1, xk.size(1), self.kv_heads, self.head_dim)xv = xv.view(-1, xv.size(1), self.kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)xk = repeat_kv(xk, self.group)xv = repeat_kv(xv, self.group)scores = torch.matmul(xq, xk.transpose(-1, -2)) / math.sqrt(self.head_dim)scores = torch.softmax(scores, dim=-1)output = torch.matmul(scores, xv)output = output.transpose(1, 2).contiguous().view(-1, x.size(1), self.heads * self.head_dim)return self.wo(output)

FeedForward 模块

        Llama3 的 MLP 模块通过线性变换、激活函数和 Dropout 组成,与 Qwen 模型一致。

代码示例:
class FeedForward(nn.Module):def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):super().__init__()self.w1 = nn.Linear(dim, hidden_dim, bias=False)self.w2 = nn.Linear(hidden_dim, dim, bias=False)self.w3 = nn.Linear(dim, hidden_dim, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

TransformerBlock: 将模块组合成完整层

        Llama3 的 Transformer 层由 Attention、FeedForward、RMSNorm 等模块组成,通过多层堆叠构建模型。

代码示例:
class TransformerBlock(nn.Module):def __init__(self, layer_id: int, args: ModelArgs):super().__init__()self.attention = Attention(args)self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim, multiple_of=args.multiple_of, dropout=args.dropout)self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)def forward(self, x, freqs_cos, freqs_sin):h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)out = h + self.feed_forward.forward(self.ffn_norm(h))return out

Transformer模型:完整的 Llama3 实现

代码示例:
class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)self.layers = nn.ModuleList([TransformerBlock(i, params) for i in range(params.n_layers)])self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = nn.Linear(params.dim, params.vocab_size, bias=False)self.tok_embeddings.weight = self.output.weightfreqs_cos, freqs_sin = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len)self.register_buffer("freqs_cos", freqs_cos, persistent=False)self.register_buffer("freqs_sin", freqs_sin, persistent=False)def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:h = self.tok_embeddings(tokens)for layer in self.layers:h = layer(h, self.freqs_cos[:h.size(1)], self.freqs_sin[:h.size(1)])h = self.norm(h)return self.output(h)

结语

        通过本篇文章,我们学习了如何从零开始预训练Llama3模型,并认识了它与Llama2在模型结构上的主要区别。Llama3的引入GQA机制大幅提升了模型的推理效率,同时结合 tiktoken 的分词方式,使其在处理文本任务时更具优势。后续,我们将进一步更新关于数据预处理和模型优化的相关教程。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • BFS之最短路径模型
  • 解决银河麒麟V10中/data目录执行权限问题
  • JDK1.8安装配置教程(图文结合,最简洁易懂)
  • undeclared identifier ‘UNITY_PREV_MATRIX_M‘ - Unity Shader自己写URP,引用内部 hlsl
  • 15年408计算机网络
  • FPGA-Vivado-IP核-逻辑分析仪(ILA)
  • 机器学习和深度学习的区别
  • 多模态——基于XrayGLM的X光片诊断的多模态大模型
  • MYSQL(学习笔记)
  • STM32F407之Flash
  • 3.4 爬虫实战-爬去智联招聘职位信息
  • 演示:基于WPF的DrawingVisual开发的频谱图和律动图
  • 【分布式微服务云原生】10分钟打造坚不可摧的系统:深入探索系统的鲁棒性
  • 在树莓派上基于 LNMP 搭建 Nextcloud
  • 图灵完备-奇数个信号
  • 【RocksDB】TransactionDB源码分析
  • Logstash 参考指南(目录)
  • python3 使用 asyncio 代替线程
  • React-redux的原理以及使用
  • Redis提升并发能力 | 从0开始构建SpringCloud微服务(2)
  • Selenium实战教程系列(二)---元素定位
  • springboot_database项目介绍
  • Spring思维导图,让Spring不再难懂(mvc篇)
  • Xmanager 远程桌面 CentOS 7
  • 阿里中间件开源组件:Sentinel 0.2.0正式发布
  • 关于Java中分层中遇到的一些问题
  • 设计模式走一遍---观察者模式
  • 王永庆:技术创新改变教育未来
  • 验证码识别技术——15分钟带你突破各种复杂不定长验证码
  • [地铁译]使用SSD缓存应用数据——Moneta项目: 低成本优化的下一代EVCache ...
  • 如何用纯 CSS 创作一个货车 loader
  • 如何在招聘中考核.NET架构师
  • ​Benvista PhotoZoom Pro 9.0.4新功能介绍
  • ‌分布式计算技术与复杂算法优化:‌现代数据处理的基石
  • (9)YOLO-Pose:使用对象关键点相似性损失增强多人姿态估计的增强版YOLO
  • (DFS + 剪枝)【洛谷P1731】 [NOI1999] 生日蛋糕
  • (Matalb分类预测)GA-BP遗传算法优化BP神经网络的多维分类预测
  • (笔试题)分解质因式
  • (七)c52学习之旅-中断
  • (小白学Java)Java简介和基本配置
  • (转)fock函数详解
  • (转)JVM内存分配 -Xms128m -Xmx512m -XX:PermSize=128m -XX:MaxPermSize=512m
  • (转)ORM
  • .NET C# 配置 Options
  • .NET CORE Aws S3 使用
  • .net Stream篇(六)
  • .net 桌面开发 运行一阵子就自动关闭_聊城旋转门家用价格大约是多少,全自动旋转门,期待合作...
  • .net打印*三角形
  • .NET未来路在何方?
  • .NET中两种OCR方式对比
  • [20140403]查询是否产生日志
  • [2023年]-hadoop面试真题(一)
  • [AIGC] SpringBoot的自动配置解析
  • [AutoSar]工程中的cpuload陷阱(三)测试
  • [BetterExplained]书写是为了更好的思考(转载)