注意力机制详解
问题
注意力机制有哪些类型和变种?Flash Attention 如何优化计算?MHA、MQA、GQA 有什么区别?
答案
注意力机制是 Transformer 的核心组件。本文深入探讨不同类型的注意力及其在现代 LLM 中的优化。
一、注意力机制的演进
二、注意力机制的类型
1. Self-Attention(自注意力)
序列内部的关注——每个 Token 看自己和其他 Token 的关系:
Q、K、V 全部来自同一输入序列的不同线性投影。
2. Cross-Attention(交叉注意力)
两个不同序列之间的关注——Q 来自一个序列,K 和 V 来自另一个序列:
- 翻译:Decoder 的 Q 查询 Encoder 的 K、V
- 多模态:文本的 Q 查询图像的 K、V
3. Causal Attention(因果注意力)
带掩码的 Self-Attention——只看当前位置之前的 Token:
# 因果掩码示例
import torch
def causal_attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
# 创建因果掩码:上三角为 -inf
seq_len = Q.size(-2)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)
return torch.matmul(weights, V)
三、MHA、MQA、GQA 对比
这是现代 LLM 推理优化中的核心话题。
| 类型 | Key/Value 头数 | KV Cache 大小 | 质量 | 代表模型 |
|---|---|---|---|---|
| MHA(Multi-Head Attention) | = Query 头数 | 大 | 最高 | GPT-3、BERT |
| MQA(Multi-Query Attention) | 1 | 最小 | 略低 | PaLM、Falcon |
| GQA(Grouped-Query Attention) | 介于 1 和 Q 头数之间 | 适中 | 接近 MHA | LLaMA 2/3、Mistral |
GQA 是 MHA 和 MQA 的折中:
- 比 MHA 显著减少 KV Cache(推理时显存大幅降低)
- 比 MQA 质量更好(多组 KV 保留更多信息)
- LLaMA 2(70B)率先采用,后续模型普遍跟进
四、Flash Attention
Flash Attention 不是新的注意力机制,而是标准注意力的高效硬件实现——算法层面完全等价,但显存降低了 5-20 倍,速度提升 2-4 倍。
核心思想:IO 感知
传统实现中,注意力矩阵 ()需要写入 HBM(GPU 显存),再从 HBM 读回做 Softmax——这个来回读写是性能瓶颈。
Flash Attention 用 分块计算(Tiling)+ 在线 Softmax 的技巧,在 SRAM (GPU 高速缓存)中完成所有计算,避免将完整注意力矩阵写入 HBM。
| 对比 | 标准 Attention | Flash Attention |
|---|---|---|
| 显存复杂度 | ||
| 速度 | 受限于 HBM 带宽 | 利用 SRAM 高带宽 |
| 是否精确 | ✅ | ✅(不是近似) |
| 支持 | PyTorch 原生 | PyTorch 2.0+ F.scaled_dot_product_attention |
Flash Attention 不是近似算法——它的数学结果和标准注意力完全相同。它是一个纯粹的系统级优化,通过减少 GPU 内存读写次数来加速。
五、长序列注意力
处理超长序列(>100K tokens)的策略:
| 方法 | 思路 | 复杂度 | 代表 |
|---|---|---|---|
| Sliding Window | 只看固定窗口内的 Token | Mistral | |
| Sparse Attention | 稀疏注意力模式 | BigBird、Longformer | |
| Ring Attention | 跨 GPU 分布注意力计算 | Llama 3 | |
| Linear Attention | 用核函数近似 Softmax | kattn |
六、KV Cache
自回归生成中,每生成一个新 Token 都需要对所有已生成 Token 做注意力计算。KV Cache 缓存之前的 K 和 V 矩阵,避免重复计算:
- 没有 KV Cache:每生成一个 Token 重新计算所有
- 有 KV Cache:只计算新 Token 的 Q 和缓存的 K、V 的注意力
KV Cache 是推理加速的关键,但也是显存消耗的大头——所以 GQA(减少 KV 头数)和量化 KV Cache 非常重要。
常见面试问题
Q1: Attention 的计算复杂度是多少?如何优化?
答案: 标准 Self-Attention 的时间复杂度 ,空间复杂度 。优化路线:
- 硬件级:Flash Attention(减少 IO)
- 架构级:GQA(减少 KV 头数)、Sliding Window
- 算法级:Sparse Attention、Linear Attention
Q2: Flash Attention 是怎么做到降低显存的?
答案: 核心技巧是避免在 HBM 中存储完整的 注意力矩阵。它将 Q、K、V 分成小块(Tile),在 SRAM 中逐块计算注意力,使用 Online Softmax 算法增量更新结果。最终效果:结果精确,但只需要 额外显存。
Q3: 什么是 KV Cache?为什么它是推理的显存瓶颈?
答案: KV Cache 缓存自回归生成过程中每一层、每一步的 K 和 V 矩阵。对于一个 70B 参数的模型(80 层、8 个 KV 头、128 维),生成 4096 个 Token 的 KV Cache 大小约为:
(FP16)
这意味着一个请求就需要 10+GB 显存来存 KV Cache,是 batch 推理的核心瓶颈。
Q4: MHA 和 GQA 的具体区别是什么?
答案:
- MHA:假设 32 个注意力头,每个头有独立的 Q、K、V 投影——共 32 组 K/V
- GQA:32 个 Q 头分成 8 组,每组 4 个 Q 头共享 1 组 K/V——共 8 组 K/V
- 效果:KV Cache 减少 4 倍,推理显存大幅降低,质量几乎不损失