当前位置: 首页 > news >正文

大模型推理--KV Cache

KV Cache是大模型推理中常用到的一个技巧,可以减少重复计算,加快推理速度。不少人只是从概念上知道它可以减少重复计算,详细的原理则知之甚少,此外为啥只有KV Cache而没有Q Cache呢,我们在本博客中给出详尽的解释。我想在博客的开头先阐明KV Cache可行的根本原因,大家可以带着这句话去看后面的内容:KV Cache之所以可行,就是因为在LLM decoding阶段的Attention计算需要mask,这就会导致前面已经生成的Token不需要与后面的Token重新计算Attention,从而每个token的生成只和KV相关,与之前的Q无关,进而可以让我们通过缓存KV的方式避免重复计算。不明所以也不要紧,你只需要耐着性子往下看即可。

1. LLM推理流程

当我们给大模型提供一段prompt的时候,大模型是如何转化成我们想要的答案的呢?这里面会经历很多步骤:

  1. 输入Tokenize化。原始的输入都是文本,也有可能是多模态类型的音频、图像,我们需要将其转化成大模型能识别的token。Token可以翻译为词元,在自然语言处理中,它通常是指一个词、一个字符或者一个子词等语言单位。比如我们针对本章开头的第一句话进行tokenize会得到如下输出:“当 我们 给 大模型 提供 一段 prompt 的 时候 , 大模型 是 如何 转化 成 我们 想要 的 答案 的呢 ?”(一个可行的切分,真实场景中不一定如此)。可以看出,tokenize就相当于传统NLP任务中的分词。大模型有一个词汇表,其中包含了模型能够识别的所有token。在将文本转化为token序列时,会根据词汇表中的token进行匹配和划分。这一步的主要目的是将原始的prompt转化为token的id序列。

  2. 获取token的embedding+位置编码。这一步的操作是将token的id序列转化成embedding序列,embedding就是每个token的向量化表示。此外,在真正开始推理之前embedding中还需要加入位置编码(positional encoding)信息,这是由于transformer结构不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示token之间的相对位置关系。具体来说,每个token在整个序列中所处的位置都会有一个位置向量,这一向量会与token的embedding向量相加。步骤1和2相当于大模型推理的前处理,将文本或者多模态输入变成一个N x d的矩阵正式进入后续大模型的推理阶段。
    在这里插入图片描述
    图1 大模型推理的两个阶段prefilling和decoding

  3. 大模型推理一般包括两个阶段prefilling和decoding,如图1所示(from paper:A Survey on Efficient Inference for Large Language Models)。Prefilling阶段就是将步骤2中生成的N x d矩阵送到大模型中做一遍推理。N就是prompt分词之后的长度,也即上下文的长度。N通常情况下会比较大,当前的大模型普遍可以支持128K的上下文。因为要一次性完成整个prompt对应token的计算,所以prefilling阶段的计算量很大,但它是一次性的过程,所以在整个推理中只占不到 10% 的时间。更核心的是要解决Attention计算过程中显存占用过大的问题,业界的主流的方案就是FlashAttention,可以参考我之前写的博客《大模型推理–FlashAttention》。大模型的最后一层往往是softmax,用来输出每个token的概率,从中选择概率最大的作为大模型的第一个输出token。事实上,如果不考虑速度因素,单纯利用prefilling也可以完成整个大模型的推理过程。每完成一次大模型推理我们就可以获得一个token,将其与prompt和之前输出的所有token组合重新开始新一轮的大模型推理,直到输出停止符为止。不过我们很容易看出这种推理模式包含非常多的重复计算,每增加一个token,prompt部分的推理就会重复一次。但是这些重复计算是没有意义的,所以才会将大模型推理拆分成了prefilling+decoding两阶段,目的就是要在prefilling阶段计算KV Cache供decoding阶段使用,当然prefilling也会生成第一个输出token。

  4. decoding阶段就是自回归输出每个token的过程。这里首先要强调一点,prefilling和decoding用的都是相同的大模型,只是推理方式不同:prefilling阶段一次性灌入整个prompt,而decoding阶段则是一次灌入一个token。所以prefilling阶段主要的操作是gemm(矩阵乘),到了decoding阶段就变成了gemv(矩阵向量乘)。正是由于 decoding阶段是逐个token生成的,每一次回答都会生成很多token,所以decoding阶段的数量非常多,占到整个推理过程的90%以上。decoding阶段就是要利用prefilling阶段生成的KV Cache来避免重复计算,这里的核心点在于mask机制保证了transformer每一个block中的Attention计算可以实现逐行运算。

将prefilling阶段输出的第一个token加decoding阶段所有的输出token组合起来即可以转换为最终的输出。这里不少人可能会有疑问,按照上面的步骤完成推理之后貌似生成的结果是固定的,但是为什么我们实际在使用大模型的时候同样的输入每次输出也不同呢?这里面涉及到了大模型采样参数的问题。在prefilling阶段我们提到,可以选择softmax层中概率最大的token作为输出,这可以看做是一种贪婪解码,我们可以完全不选择top1作为输出,而是引入更多采样策略让生成的结果更加多样化,比如从topk中随机选择一个,又或者引入beam-search之类的解码策略,更完整的介绍大家可以参考:《优化采样参数提升大语言模型响应质量:深入分析温度、top_p、top_k和min_p的随机解码策略》。

2. 单层transformer推理过程

给定一个输入 X X X,维度为N x d(忽略batch维),我们将其送入图1左侧所示的大模型一个block里面,看看具体的执行过程。如前所述,N表示prompt上下文长度,可以很长,目前大模型普遍可以支持到128K甚至更长。d表示embedding的维度,通常情况根据大模型的尺寸不同从4K到12K不等。输入X会先送入三个projection矩阵 W Q , W K , W V W_Q,W_K,W_V WQWKWV进行线性变换:
Q = X ∗ W Q \begin{equation} Q=X*W_Q \end{equation} Q=XWQ K = X ∗ W K \begin{equation} K=X*W_K \end{equation} K=XWK V = X ∗ W V \begin{equation} V=X*W_V \end{equation} V=XWV这三个矩阵的维度一般相等,通常都是d x d,所以输出维度还是N x d。上述三个公式都是 X X X乘以 W W W,所以 X X X增加一行对应的结果 Q 、 K 、 V Q、K、V QKV也会增加一行。变换完成之后生成的 K 、 V K、V KV即是我们要详细介绍的KV Cache。从这三个公式也可以看出,如果单纯采用prefilling完成大模型推理,虽然可以获得正确结果,但是 X X X在prompt部分的推理一直在重复。单单看这一步,我们确实可以分成prefilling和decoding两阶段,prefilling对完整的prompt进行推理,后续的自回归解码就可以逐行推理,因为每一行的计算都不会对之前的行产生影响。
上面的三个矩阵乘只是block的第一步,还没有用到完整的 Q 、 K 、 V Q、K、V QKV,更关键的是下一步Attention的计算。Attention的计算按照公式: O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V \begin{equation} O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V \end{equation} O=Attention(Q,K,V)=softmax(d QKT)V完成,在这一步中完整的 Q 、 K 、 V Q、K、V QKV都会参与运算,如图2所示。
在这里插入图片描述
从上图可以看出,Attention在计算过程中会生成一个N x N的临时矩阵,这个矩阵在N很大时会占用大量空间,FlashAttention就是为了避免生成这个临时矩阵而设计的。Attention计算的输入是三个N x d的矩阵 Q 、 K 、 V Q、K、V QKV,输出还是N x d的矩阵 O O O,到这里我们就会迫切想知道如果我们给输入增加一行变为(N+1) x d,输出的变化是怎样的。如果输出只是在旧的 O O O基础上新增一行,那就有机会实现逐行计算;如果会导致 O O O每一行都发生变化,则逐行计算就不可能实现。幸运的是,由于Mask机制的存在,大模型decoding阶段Attention计算的输出 O O O确实是逐行更新的,不会对之前行的结果进行修改。这也就确保了transformer一个block内Attention部分逐行计算的可行性。
Attention之后还包括Add、LayerNorm、Activation、FeedForward等运算,这些运算从原理上来说都非常容易实现逐行计算。所以放眼transformer整个block,只要解决了Attention部分的逐行计算,整个block就可以逐行计算,进而整个transformer大模型就都可以实现逐行计算,如此大模型的decoding阶段就可以按照自回归的方式逐个token逐个token的生成了。

3. Attention逐行计算的原理

图2展示的算是prefilling阶段的Attention计算,输入就是所有的token,我们将decoding阶段包含进来,在图2的基础上增加一行,看看会出现什么变化。
在这里插入图片描述
图 3 展示的相当于把第一次做 decoding 的 Attention 运算利用 prefilling 来实现的例子。Prefilling 阶段已经对 N x d 的输入完成了推理,并生成了第一个 token,我们将这一个 token和之前的 N 个 token 组合在一起继续进行推理。图 3 中我添加了很多问号,用来表示如果不添加限制,就会导致中间临时矩阵每一行之前计算的 softmax 不再正确,需要重新计算,进而导致后续乘以 V V V 时得到最终的输出 O O O 每一个位置的值都发生了变化。这种情况下就说明,新增一个 token 就会导致 Attention 计算输出 O O O 的完全改变,完全无法实现逐行计算。
但是我们真实使用的Attention是带mask的Attention,所以就不会出现上面所述的可能。具体来说,前面已经生成的token不需要与后面的token产生Attention,也即 Q Q Q前N行无需和 K T K^T KT的第N+1列进行计算,这样就会导致中间临时矩阵最后一列的前N行的值其实为0(mask操作一般在softmax之前,会将最后一列的前N行变为负无穷,等价于softmax之后的0),从而也就不会影响之前计算的softmax。中间临时矩阵前N行的值没有改变,在与 V V V相乘时也就不会导致 O O O的前N行发生变化,这就相当于 O O O只是在末尾增加了一行,如图4所示。当然不止第N+1行的计算如此,在Mask的影响下,后续N+2,N+3等等都是如此。
在这里插入图片描述
所以我们可以得到一个结论:在完成prefilling阶段之后,对N+1个输入再次进行prefilling只会产生第N+1行一个新的输出,旧的输出都不会改变。这就启发我们只对第N+1行的输入进行计算,这样很自然地就引入了decoding。我们将图4进行简化,只灌入一行 Q Q Q得到图5。
在这里插入图片描述
从图5我们又很清晰地看到,为了计算 Q Q Q的Attention值,我们需要完整的 K K K V V V,这也就是我们本博客要介绍的KV Cache!我们做一个总结:大模型自回归的机制使得单纯的prefilling推理存在大量的重复计算,带mask的Attention又使我们可以将重复计算去掉,为了实现逐行的Attention计算我们又必须要保留完整的 K K K V V V,这也就是KV Cache的基本原理。也正是因为Mask的原因,我们无需保留完整的 Q Q Q,只需要每来一个 Q Q Q进行逐行计算即可,这也是为啥存在KV Cache但是没有Q Cache的原因。希望通过本博客的介绍大家可以真正搞懂KV Cache。

4. 参考

  1. 大模型推理优化技术-KV Cache

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • SCL 常见问题
  • 异常整理(JAVA基础)
  • 【C++】STL容器详解【上】
  • Java——堆
  • 路灯集中控制器与智慧照明:塑造未来城市的智能光影
  • 亦菲喊你来学机器学习(20) --PCA数据降维
  • 江协科技stm32————11-5 硬件SPI读写W25Q64
  • explicit 的作用(如何避免编译器进行隐式类型转换)
  • 并发编程:synchronized 关键字
  • 【Linux】Linux 可重入函数
  • 0.ffmpeg面向对象oopc
  • 项目实战系列三: 家居购项目 第五部分
  • C++ STL-Map容器从入门到精通详解
  • HarmonyOs DevEco Studio小技巧9--翻译软件
  • 怎么利用XML发送物流快递通知短信
  • SegmentFault for Android 3.0 发布
  • 30天自制操作系统-2
  • ABAP的include关键字,Java的import, C的include和C4C ABSL 的import比较
  • CSS魔法堂:Absolute Positioning就这个样
  • Java 网络编程(2):UDP 的使用
  • JavaScript对象详解
  • javascript面向对象之创建对象
  • Nodejs和JavaWeb协助开发
  • opencv python Meanshift 和 Camshift
  • SQLServer插入数据
  • 发布国内首个无服务器容器服务,运维效率从未如此高效
  • 前端学习笔记之原型——一张图说明`prototype`和`__proto__`的区别
  • 扫描识别控件Dynamic Web TWAIN v12.2发布,改进SSL证书
  • 用 Swift 编写面向协议的视图
  • 在Unity中实现一个简单的消息管理器
  • Nginx惊现漏洞 百万网站面临“拖库”风险
  • ​iOS实时查看App运行日志
  • # Pytorch 中可以直接调用的Loss Functions总结:
  • #pragam once 和 #ifndef 预编译头
  • #知识分享#笔记#学习方法
  • %3cscript放入php,跟bWAPP学WEB安全(PHP代码)--XSS跨站脚本攻击
  • (26)4.7 字符函数和字符串函数
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (html5)在移动端input输入搜索项后 输入法下面为什么不想百度那样出现前往? 而我的出现的是换行...
  • (Mirage系列之二)VMware Horizon Mirage的经典用户用例及真实案例分析
  • (Note)C++中的继承方式
  • (pt可视化)利用torch的make_grid进行张量可视化
  • (ZT)薛涌:谈贫说富
  • (编程语言界的丐帮 C#).NET MD5 HASH 哈希 加密 与JAVA 互通
  • (待修改)PyG安装步骤
  • (二十四)Flask之flask-session组件
  • (附源码)springboot助农电商系统 毕业设计 081919
  • (附源码)计算机毕业设计SSM疫情下的学生出入管理系统
  • (理论篇)httpmoudle和httphandler一览
  • (六)激光线扫描-三维重建
  • (论文阅读11/100)Fast R-CNN
  • (三十)Flask之wtforms库【剖析源码上篇】
  • (十一)c52学习之旅-动态数码管
  • (一)eclipse Dynamic web project 工程目录以及文件路径问题
  • (一)基于IDEA的JAVA基础1