LLM 中的注意力机制
问题
在大语言模型推理过程中,注意力机制面临什么挑战?KV Cache 是什么?如何支撑长上下文?
答案
本文聚焦注意力机制在 LLM 推理场景下的特殊挑战和优化,是 注意力机制详解 在 LLM 场景的延伸。
一、LLM 推理的两个阶段
| 阶段 | 特点 | 瓶颈 |
|---|---|---|
| Prefill | 并行处理所有输入 Token | 计算密集型(GPU 算力) |
| Decode | 每次只生成一个 Token,依赖所有已生成 Token | 内存密集型(KV Cache 带宽) |
二、KV Cache 详解
为什么需要 KV Cache?
自回归生成中,每生成一个新 Token 都需要和所有之前的 Token 做注意力计算。如果每次都重新计算所有 K 和 V,第 步的计算量就是 ——随着序列增长线性增加。
KV Cache:将之前所有层的 K 和 V 矩阵缓存在显存中,每一步只计算新 Token 的 Q,和缓存的 K/V 做注意力。
KV Cache 显存计算
- :层数
- :KV 头数
- :头维度
- :序列长度
- bytes:FP16 = 2 字节
示例:LLaMA 2-70B(80 层、8 KV 头、128 维、序列 4096)
KV Cache 通常比模型权重消耗更多显存。batch size 为 32 时,KV Cache 就需要 342GB——远超模型本身的 140GB(FP16)。这就是为什么推理优化如此重要。
三、KV Cache 优化技术
| 技术 | 原理 | 节省倍数 | 代表 |
|---|---|---|---|
| GQA | 多个 Q 头共享 KV | 4-8× | LLaMA 2/3 |
| MQA | 所有 Q 头共享一个 KV | 32× | PaLM, Falcon |
| KV Cache 量化 | 将 KV 从 FP16 量化为 INT8/INT4 | 2-4× | vLLM |
| PagedAttention | 操作系统分页思想管理 KV Cache | 避免碎片 | vLLM |
| Sliding Window | 只保留最近 W 个 Token 的 KV | W/n× | Mistral |
PagedAttention(vLLM 核心)
传统 KV Cache 为每个请求预分配最大长度的连续内存——大量浪费。PagedAttention 借鉴操作系统的虚拟内存分页:
- KV Cache 被分成固定大小的"页面"(Block)
- 按需分配页面,不预留最大长度
- 不同请求的 KV Cache 页面可以不连续
- 结果:显存利用率从约 50% 提升到 >95%
四、长上下文注意力
现代 LLM 支持 128K-1M Token 的上下文窗口,核心技术:
1. RoPE 外推与内插
RoPE 天然支持位置编码外推,但直接外推效果会下降。常用方法:
| 方法 | 思路 | 典型模型 |
|---|---|---|
| Position Interpolation | 线性内插到更长位置 | LLaMA Long |
| NTK-aware Interpolation | 调整 RoPE 基频 | Code LLaMA |
| YaRN | 结合内插 + NTK + 温度缩放 | Mistral、Qwen |
| ABF(Adjusted Base Frequency) | 直接增大 base 频率 | LLaMA 3 |
2. Sliding Window Attention
Mistral 使用固定窗口大小(如 4096),每层只看最近 4096 个 Token。通过层层叠加,信息可以在更远的位置传播(理论感受野 = 层数 × 窗口大小)。
3. Ring Attention
将长序列分布在多个 GPU 上,每个 GPU 处理一段序列的 QKV,通过环形通信传递 KV 块。
五、推理优化:从 Prefill 到 Decode
Speculative Decoding(投机采样)
用小模型(Draft Model)快速生成多个候选 Token,大模型一次性验证——如果大部分被接受,等效于一步生成多个 Token:
加速比取决于小模型和大模型的一致率,通常 2-3 倍。
Continuous Batching
传统推理等所有请求都结束才释放 batch——短请求被长请求"拖累"。vLLM 等引擎的 Continuous Batching 允许每步动态加入/移除请求,大幅提升吞吐。
常见面试问题
Q1: KV Cache 是什么?为什么是推理的瓶颈?
答案: KV Cache 缓存自回归生成中每层的 Key 和 Value 矩阵。瓶颈在于:
- 显存消耗:随序列长度和 batch size 线性增长,很容易超过模型本身
- 带宽瓶颈:Decode 阶段每步都需要读取整个 KV Cache,是内存带宽瓶颈
- 碎片化:预分配最大长度导致显存浪费(PagedAttention 解决)
Q2: GQA 如何减少 KV Cache?
答案: GQA 让多个 Query 头共享一组 KV 头。例如 LLaMA 2-70B 有 64 个 Q 头但只有 8 个 KV 头——KV Cache 减少了 8 倍。模型质量几乎不受影响,因为多个 Q 头仍然学习了不同的查询模式。
Q3: vLLM 的 PagedAttention 解决了什么问题?
答案: 解决了 KV Cache 的显存碎片化问题。传统做法为每个请求预分配最大序列长度的连续显存块,请求实际长度通常远短于最大值,造成 40-60% 的显存浪费。PagedAttention 按需分配非连续的小块(类似 OS 的虚拟内存分页),显存利用率提升到 >95%,batch size 可增大 2-4 倍。
Q4: 如何让模型支持更长的上下文?
答案:
- 位置编码外推:ABF(增大 RoPE base)、YaRN(内插 + NTK)
- 注意力优化:Flash Attention(减少显存)、Sliding Window(限制每层窗口)
- 分布式计算:Ring Attention(跨 GPU 分割序列)
- 训练数据:需要在长序列数据上继续训练,否则模型无法有效利用长上下文
- 实测工具:NIAH(Needle in a Haystack)测试模型在长上下文中的信息检索能力