当前位置: 首页 > news >正文

什么是键值缓存?让 LLM 闪电般快速

一、为什么 LLMs 需要 KV 缓存?

大语言模型(LLMs)的文本生成遵循 “自回归” 模式 —— 每次仅输出一个 token(如词语、字符或子词),再将该 token 与历史序列拼接,作为下一轮输入,直到生成完整文本。这种模式的核心计算成本集中在注意力机制上:每个 token 的输出都依赖于它与所有历史 token 的关联,而注意力机制的计算复杂度会随序列长度增长而急剧上升。

以生成一个长度为 n 的序列为例,若不做优化,每生成第 m 个 token 时,模型需要重新计算前 m 个 token 的 “查询(Q)、键(K)、值(V)” 矩阵,导致重复计算量随 m 的增长呈平方级增加(时间复杂度 O (n²))。当 n 达到数千(如长文本生成),这种重复计算会让推理速度变得极慢。KV 缓存(Key-Value Caching)正是为解决这一问题而生 —— 通过 “缓存” 历史计算的 K 和 V,避免重复计算,将推理效率提升数倍,成为 LLMs 实现实时交互的核心技术之一。

二、注意力机制:KV 缓存优化的 “靶心”

要理解 KV 缓存的作用,需先明确注意力机制的计算逻辑。在 Transformer 架构中,注意力机制的核心公式为:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • Q(查询矩阵):维度为(n \times d_k),代表当前 token 对 “需要关注什么” 的查询;
  • K(键矩阵):维度为(n \times d_k),代表历史 token 的 “特征标识”;
  • V(值矩阵):维度为(n \times d_v),代表历史 token 的 “特征值”(通常d_v = d_k);
  • d_k是Q和K的维度(由模型维度d_{\text{model}}和注意力头数决定,如d_k = \frac{d_{\text{model}}}{\text{num\_heads}});
  • QK^T会生成一个(n \times n)的注意力分数矩阵,描述每个 token 与其他所有 token 的关联强度;
  • 经过 softmax 归一化后与V相乘,最终得到每个 token 的注意力输出(维度(n \times d_v))。

三、KV 缓存的核心原理:“记住” 历史,避免重复计算

自回归生成的痛点在于:每轮生成新 token 时,历史 token 的 K 和 V 会被重复计算。例如:

  • 生成第 3 个 token 时,输入序列是[t_1, t_2],已计算过t_1t_2K_1, K_2V_1, V_2
  • 生成第 4 个 token 时,输入序列变为[t_1, t_2, t_3],若不优化,模型会重新计算t_1, t_2, t_3的K和V—— 其中t_1, t_2的K、V与上一轮完全相同,属于无效重复。

KV 缓存的解决方案极其直接:

  1. 缓存历史 K 和 V:每生成一个新 token 后,将其K和V存入缓存,与历史缓存的K、V拼接;
  2. 仅计算新 token 的 K 和 V:下一轮生成时,无需重新计算所有 token 的K、V,只需为新 token 计算K_{\text{new}}V_{\text{new}},再与缓存拼接,直接用于注意力计算。

这一过程将每轮迭代的计算量从 “重新计算 n 个 token 的 K、V” 减少到 “计算 1 个新 token 的 K、V”,时间复杂度从O(n²)优化为接近O(n),尤其在生成长文本时,效率提升会非常显著。

四、代码实现:从 “无缓存” 到 “有缓存” 的对比

以下用 PyTorch 代码模拟单头注意力机制,直观展示 KV 缓存的作用(假设模型维度d_{\text{model}}=64d_k=64):

import torch
import torch.nn.functional as F# 1. 定义基础参数与注意力函数
d_model = 64  # 模型维度
d_k = d_model  # 单头注意力中Q、K的维度
batch_size = 1  # 批量大小def scaled_dot_product_attention(Q, K, V):"""计算缩放点积注意力"""# 步骤1:计算注意力分数 (n×d_k) @ (d_k×n) → (n×n)scores = torch.matmul(Q, K.transpose(-2, -1))  # 转置K的最后两维,实现矩阵乘法scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))  # 缩放# 步骤2:softmax归一化,得到注意力权重 (n×n)attn_weights = F.softmax(scores, dim=-1)  # 沿最后一维归一化# 步骤3:加权求和 (n×n) @ (n×d_k) → (n×d_k)output = torch.matmul(attn_weights, V)return output, attn_weights# 2. 模拟输入数据:历史序列与新token
# 历史序列(已生成3个token)的嵌入向量:shape=(batch_size, seq_len, d_model)
prev_embeds = torch.randn(batch_size, 3, d_model)  # 1×3×64
# 新生成的第4个token的嵌入向量:shape=(1, 1, 64)
new_embed = torch.randn(batch_size, 1, d_model)# 3. 模型中用于计算K、V的权重矩阵(假设已训练好)
Wk = torch.randn(d_model, d_k)  # 用于从嵌入向量映射到K:64×64
Wv = torch.randn(d_model, d_k)  # 用于从嵌入向量映射到V:64×64# 场景1:无KV缓存——重复计算所有token的K、V
full_embeds_no_cache = torch.cat([prev_embeds, new_embed], dim=1)  # 拼接为1×4×64
# 重新计算4个token的K和V(包含前3个的重复计算)
K_no_cache = torch.matmul(full_embeds_no_cache, Wk)  # 1×4×64(前3个与历史重复)
V_no_cache = torch.matmul(full_embeds_no_cache, Wv)  # 1×4×64(前3个与历史重复)
# 计算注意力(Q使用当前序列的嵌入向量,此处简化为与K相同)
output_no_cache, _ = scaled_dot_product_attention(K_no_cache, K_no_cache, V_no_cache)# 场景2:有KV缓存——仅计算新token的K、V,复用历史缓存
# 缓存前3个token的K、V(上一轮已计算,无需重复)
K_cache = torch.matmul(prev_embeds, Wk)  # 1×3×64(历史缓存)
V_cache = torch.matmul(prev_embeds, Wv)  # 1×3×64(历史缓存)# 仅计算新token的K、V
new_K = torch.matmul(new_embed, Wk)  # 1×1×64(新计算)
new_V = torch.matmul(new_embed, Wv)  # 1×1×64(新计算)# 拼接缓存与新K、V,得到完整的K、V矩阵(与无缓存时结果一致)
K_with_cache = torch.cat([K_cache, new_K], dim=1)  # 1×4×64
V_with_cache = torch.cat([V_cache, new_V], dim=1)  # 1×4×64# 计算注意力(结果与无缓存完全相同,但计算量减少)
output_with_cache, _ = scaled_dot_product_attention(K_with_cache, K_with_cache, V_with_cache)# 验证:两种方式的输出是否一致(误差在浮点精度范围内)
print(torch.allclose(output_no_cache, output_with_cache, atol=1e-6))  # 输出:True

代码中,“有缓存” 模式通过复用前 3 个 token 的 K、V,仅计算新 token 的 K、V,就得到了与 “无缓存” 模式完全一致的结果,但计算量减少了 3/4(对于 4 个 token 的序列)。当序列长度增至 1000,这种优化会让每轮迭代的计算量从 1000 次矩阵乘法减少到 1 次,效率提升极其显著。

五、权衡:内存与性能的平衡

KV 缓存虽能提升速度,但需面对 “内存占用随序列长度线性增长” 的问题:

  • 缓存的 K 和 V 矩阵维度为(n \times d_k),当序列长度 n 达到 10000,且d_k=64时,单头注意力的缓存大小约为10000 \times 64 \times 2(K 和 V 各一份)=1.28 \times 10^6个参数,若模型有 12 个注意力头,总缓存会增至约 150 万参数,对显存(尤其是 GPU)是不小的压力。

为解决这一问题,实际应用中会采用以下优化策略:

  • 滑动窗口缓存:仅保留最近的k个 token 的 K、V(如 k=2048),超过长度则丢弃最早的缓存,适用于对长距离依赖要求不高的场景;
  • 动态缓存管理:根据输入序列长度自动调整缓存策略,在短序列时全量缓存,长序列时启用滑动窗口;
  • 量化缓存:将 K、V 从 32 位浮点(float32)量化为 16 位(float16)或 8 位(int8),以牺牲少量精度换取内存节省,目前主流 LLMs(如 GPT-3、LLaMA)均采用此方案。

六、实际应用:KV 缓存如何支撑 LLMs 的实时交互?

在实际部署中,KV 缓存是 LLMs 实现 “秒级响应” 的关键。例如:

  • 聊天机器人(如 ChatGPT)生成每句话时,通过 KV 缓存避免重复计算历史对话的 K、V,让长对话仍能保持流畅响应;
  • 代码生成工具(如 GitHub Copilot)在补全长代码时,缓存已输入的代码 token 的 K、V,确保补全速度与输入长度无关;
  • 语音转文本实时生成(如实时字幕)中,KV 缓存能让模型随语音输入逐词生成文本,延迟控制在数百毫秒内。

可以说,没有 KV 缓存,当前 LLMs 的 “实时交互” 体验几乎无法实现 —— 它是平衡模型性能与推理效率的 “隐形支柱”。

总结

KV 缓存通过复用历史 token 的 K 和 V 矩阵,从根本上解决了 LLMs 自回归生成中的重复计算问题,将时间复杂度从O(n²)优化为接近O(n)。其核心逻辑简单却高效:“记住已经算过的,只算新的”。尽管需要在内存与性能间做权衡,但通过滑动窗口、量化等策略,KV 缓存已成为现代 LLMs 推理不可或缺的技术,支撑着从聊天机器人到代码生成的各类实时交互场景。

http://www.lryc.cn/news/612682.html

相关文章:

  • OpenCV的关于图片的一些运用
  • 数据分析进阶——53页跨境数据分析【附全文阅读】
  • 僵尸进程问题排查
  • Mac+Chrome滚动截图
  • localforage的数据仓库、实例、storeName和name的概念和区别
  • OpenAI 开源模型 gpt-oss 正式上线微软 Foundry 平台
  • [Oracle] CEIL()函数
  • 利用微软SQL Server数据库管理员(SA)口令为空的攻击活动猖獗
  • MySQL中的DDL(一)
  • 直连微软,下载速度达18M/S
  • [2402MT-A] Redbag
  • 从周末去哪儿玩到决策树:机器学习算法的生活启示
  • 《深入解析缓存三大难题:穿透、雪崩、击穿及应对之道》
  • Mysql数据仓库备份脚本
  • 突破距离桎梏:5G 高清视频终端如何延伸无人机图传边界
  • 【完整源码+数据集+部署教程】无人机自然场景分割系统源码和数据集:改进yolo11-RVB
  • 计算机网络1-4:计算机网络的定义和分类
  • 【网络编程】一请求一线程
  • 云原生安全挑战与治理策略:从架构思维到落地实践
  • PyTorch + PaddlePaddle 语音识别
  • 从BaseMapper到LambdaWrapper:MyBatis-Plus的封神之路
  • day44 力扣1143.最长公共子序列 力扣1035.不相交的线 力扣53. 最大子序和 力扣392.判断子序列
  • WEB开发-第二十七天(PHP篇)
  • 笔试——Day31
  • Linux(17)——Linux进程信号(下)
  • 【42】【OpenCV C++】 计算图像某一列像素方差 或 某一行像素的方差;
  • uniapp vue3中使用pinia 和 pinia持久化(没有使用ts)
  • SQLite 创建表
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-文章列表
  • [失败记录] 使用HBuilderX创建的uniapp vue3项目添加tailwindcss3的完整过程