跳到主要内容

注意力机制详解

问题

注意力机制有哪些类型和变种?Flash Attention 如何优化计算?MHA、MQA、GQA 有什么区别?

答案

注意力机制是 Transformer 的核心组件。本文深入探讨不同类型的注意力及其在现代 LLM 中的优化。

一、注意力机制的演进

二、注意力机制的类型

1. Self-Attention(自注意力)

序列内部的关注——每个 Token 看自己和其他 Token 的关系:

Attention(Q,K,V)=Softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Q、K、V 全部来自同一输入序列的不同线性投影。

2. Cross-Attention(交叉注意力)

两个不同序列之间的关注——Q 来自一个序列,K 和 V 来自另一个序列:

  • 翻译:Decoder 的 Q 查询 Encoder 的 K、V
  • 多模态:文本的 Q 查询图像的 K、V

3. Causal Attention(因果注意力)

带掩码的 Self-Attention——只看当前位置之前的 Token:

# 因果掩码示例
import torch

def causal_attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

# 创建因果掩码:上三角为 -inf
seq_len = Q.size(-2)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))

weights = torch.softmax(scores, dim=-1)
return torch.matmul(weights, V)

三、MHA、MQA、GQA 对比

这是现代 LLM 推理优化中的核心话题。

类型Key/Value 头数KV Cache 大小质量代表模型
MHA(Multi-Head Attention)= Query 头数最高GPT-3、BERT
MQA(Multi-Query Attention)1最小略低PaLM、Falcon
GQA(Grouped-Query Attention)介于 1 和 Q 头数之间适中接近 MHALLaMA 2/3、Mistral
为什么 GQA 是当前主流?

GQA 是 MHA 和 MQA 的折中:

  • 比 MHA 显著减少 KV Cache(推理时显存大幅降低)
  • 比 MQA 质量更好(多组 KV 保留更多信息)
  • LLaMA 2(70B)率先采用,后续模型普遍跟进

四、Flash Attention

Flash Attention 不是新的注意力机制,而是标准注意力的高效硬件实现——算法层面完全等价,但显存降低了 5-20 倍,速度提升 2-4 倍。

核心思想:IO 感知

传统实现中,注意力矩阵 S=QKTS = QK^Tn×nn \times n)需要写入 HBM(GPU 显存),再从 HBM 读回做 Softmax——这个来回读写是性能瓶颈。

Flash Attention 用 分块计算(Tiling)+ 在线 Softmax 的技巧,在 SRAM (GPU 高速缓存)中完成所有计算,避免将完整注意力矩阵写入 HBM。

对比标准 AttentionFlash Attention
显存复杂度O(n2)O(n^2)O(n)O(n)
速度受限于 HBM 带宽利用 SRAM 高带宽
是否精确✅(不是近似)
支持PyTorch 原生PyTorch 2.0+ F.scaled_dot_product_attention
面试要点

Flash Attention 不是近似算法——它的数学结果和标准注意力完全相同。它是一个纯粹的系统级优化,通过减少 GPU 内存读写次数来加速。

五、长序列注意力

处理超长序列(>100K tokens)的策略:

方法思路复杂度代表
Sliding Window只看固定窗口内的 TokenO(n×w)O(n \times w)Mistral
Sparse Attention稀疏注意力模式O(nn)O(n\sqrt{n})BigBird、Longformer
Ring Attention跨 GPU 分布注意力计算O(n2/p)O(n^2/p)Llama 3
Linear Attention用核函数近似 SoftmaxO(n)O(n)kattn

六、KV Cache

自回归生成中,每生成一个新 Token 都需要对所有已生成 Token 做注意力计算。KV Cache 缓存之前的 K 和 V 矩阵,避免重复计算:

  • 没有 KV Cache:每生成一个 Token 重新计算所有 QKTQK^T
  • 有 KV Cache:只计算新 Token 的 Q 和缓存的 K、V 的注意力

KV Cache 是推理加速的关键,但也是显存消耗的大头——所以 GQA(减少 KV 头数)和量化 KV Cache 非常重要。


常见面试问题

Q1: Attention 的计算复杂度是多少?如何优化?

答案: 标准 Self-Attention 的时间复杂度 O(n2d)O(n^2 \cdot d),空间复杂度 O(n2)O(n^2)。优化路线:

  1. 硬件级:Flash Attention(减少 IO)
  2. 架构级:GQA(减少 KV 头数)、Sliding Window
  3. 算法级:Sparse Attention、Linear Attention

Q2: Flash Attention 是怎么做到降低显存的?

答案: 核心技巧是避免在 HBM 中存储完整的 n×nn \times n 注意力矩阵。它将 Q、K、V 分成小块(Tile),在 SRAM 中逐块计算注意力,使用 Online Softmax 算法增量更新结果。最终效果:结果精确,但只需要 O(n)O(n) 额外显存。

Q3: 什么是 KV Cache?为什么它是推理的显存瓶颈?

答案: KV Cache 缓存自回归生成过程中每一层、每一步的 K 和 V 矩阵。对于一个 70B 参数的模型(80 层、8 个 KV 头、128 维),生成 4096 个 Token 的 KV Cache 大小约为:

80×8×4096×128×2×2=10.7GB80 \times 8 \times 4096 \times 128 \times 2 \times 2 = 10.7\text{GB}(FP16)

这意味着一个请求就需要 10+GB 显存来存 KV Cache,是 batch 推理的核心瓶颈。

Q4: MHA 和 GQA 的具体区别是什么?

答案

  • MHA:假设 32 个注意力头,每个头有独立的 Q、K、V 投影——共 32 组 K/V
  • GQA:32 个 Q 头分成 8 组,每组 4 个 Q 头共享 1 组 K/V——共 8 组 K/V
  • 效果:KV Cache 减少 4 倍,推理显存大幅降低,质量几乎不损失

相关链接