跳到主要内容

LLM 中的注意力机制

问题

在大语言模型推理过程中,注意力机制面临什么挑战?KV Cache 是什么?如何支撑长上下文?

答案

本文聚焦注意力机制在 LLM 推理场景下的特殊挑战和优化,是 注意力机制详解 在 LLM 场景的延伸。

一、LLM 推理的两个阶段

阶段特点瓶颈
Prefill并行处理所有输入 Token计算密集型(GPU 算力)
Decode每次只生成一个 Token,依赖所有已生成 Token内存密集型(KV Cache 带宽)

二、KV Cache 详解

为什么需要 KV Cache?

自回归生成中,每生成一个新 Token 都需要和所有之前的 Token 做注意力计算。如果每次都重新计算所有 K 和 V,第 tt 步的计算量就是 O(td)O(t \cdot d)——随着序列增长线性增加。

KV Cache:将之前所有层的 K 和 V 矩阵缓存在显存中,每一步只计算新 Token 的 Q,和缓存的 K/V 做注意力。

KV Cache 显存计算

KV Cache=2×L×nkv×dh×s×bytes\text{KV Cache} = 2 \times L \times n_{kv} \times d_h \times s \times \text{bytes}
  • LL:层数
  • nkvn_{kv}:KV 头数
  • dhd_h:头维度
  • ss:序列长度
  • bytes:FP16 = 2 字节

示例:LLaMA 2-70B(80 层、8 KV 头、128 维、序列 4096)

2×80×8×128×4096×2=10.7GB(单个请求!)2 \times 80 \times 8 \times 128 \times 4096 \times 2 = 10.7\text{GB(单个请求!)}
推理的显存瓶颈

KV Cache 通常比模型权重消耗更多显存。batch size 为 32 时,KV Cache 就需要 342GB——远超模型本身的 140GB(FP16)。这就是为什么推理优化如此重要。

三、KV Cache 优化技术

技术原理节省倍数代表
GQA多个 Q 头共享 KV4-8×LLaMA 2/3
MQA所有 Q 头共享一个 KV32×PaLM, Falcon
KV Cache 量化将 KV 从 FP16 量化为 INT8/INT42-4×vLLM
PagedAttention操作系统分页思想管理 KV Cache避免碎片vLLM
Sliding Window只保留最近 W 个 Token 的 KVW/n×Mistral

PagedAttention(vLLM 核心)

传统 KV Cache 为每个请求预分配最大长度的连续内存——大量浪费。PagedAttention 借鉴操作系统的虚拟内存分页:

  • KV Cache 被分成固定大小的"页面"(Block)
  • 按需分配页面,不预留最大长度
  • 不同请求的 KV Cache 页面可以不连续
  • 结果:显存利用率从约 50% 提升到 >95%

四、长上下文注意力

现代 LLM 支持 128K-1M Token 的上下文窗口,核心技术:

1. RoPE 外推与内插

RoPE 天然支持位置编码外推,但直接外推效果会下降。常用方法:

方法思路典型模型
Position Interpolation线性内插到更长位置LLaMA Long
NTK-aware Interpolation调整 RoPE 基频Code LLaMA
YaRN结合内插 + NTK + 温度缩放Mistral、Qwen
ABF(Adjusted Base Frequency)直接增大 base 频率LLaMA 3

2. Sliding Window Attention

Mistral 使用固定窗口大小(如 4096),每层只看最近 4096 个 Token。通过层层叠加,信息可以在更远的位置传播(理论感受野 = 层数 × 窗口大小)。

3. Ring Attention

将长序列分布在多个 GPU 上,每个 GPU 处理一段序列的 QKV,通过环形通信传递 KV 块。

五、推理优化:从 Prefill 到 Decode

Speculative Decoding(投机采样)

用小模型(Draft Model)快速生成多个候选 Token,大模型一次性验证——如果大部分被接受,等效于一步生成多个 Token:

加速比取决于小模型和大模型的一致率,通常 2-3 倍。

Continuous Batching

传统推理等所有请求都结束才释放 batch——短请求被长请求"拖累"。vLLM 等引擎的 Continuous Batching 允许每步动态加入/移除请求,大幅提升吞吐。


常见面试问题

Q1: KV Cache 是什么?为什么是推理的瓶颈?

答案: KV Cache 缓存自回归生成中每层的 Key 和 Value 矩阵。瓶颈在于:

  1. 显存消耗:随序列长度和 batch size 线性增长,很容易超过模型本身
  2. 带宽瓶颈:Decode 阶段每步都需要读取整个 KV Cache,是内存带宽瓶颈
  3. 碎片化:预分配最大长度导致显存浪费(PagedAttention 解决)

Q2: GQA 如何减少 KV Cache?

答案: GQA 让多个 Query 头共享一组 KV 头。例如 LLaMA 2-70B 有 64 个 Q 头但只有 8 个 KV 头——KV Cache 减少了 8 倍。模型质量几乎不受影响,因为多个 Q 头仍然学习了不同的查询模式。

Q3: vLLM 的 PagedAttention 解决了什么问题?

答案: 解决了 KV Cache 的显存碎片化问题。传统做法为每个请求预分配最大序列长度的连续显存块,请求实际长度通常远短于最大值,造成 40-60% 的显存浪费。PagedAttention 按需分配非连续的小块(类似 OS 的虚拟内存分页),显存利用率提升到 >95%,batch size 可增大 2-4 倍。

Q4: 如何让模型支持更长的上下文?

答案

  1. 位置编码外推:ABF(增大 RoPE base)、YaRN(内插 + NTK)
  2. 注意力优化:Flash Attention(减少显存)、Sliding Window(限制每层窗口)
  3. 分布式计算:Ring Attention(跨 GPU 分割序列)
  4. 训练数据:需要在长序列数据上继续训练,否则模型无法有效利用长上下文
  5. 实测工具:NIAH(Needle in a Haystack)测试模型在长上下文中的信息检索能力

相关链接