MLA = Multi-Head Latent Attention
一种通过低秩联合压缩 Key/Value 来减少 KV 缓存、提升推理效率的注意力机制,由 DeepSeek 团队在 DeepSeek-V2 中首次提出,在保持多头表达力的同时,实现接近 MQA 的内存效率。


🎯 核心动机:为什么需要 MLA?

标准 Multi-Head Attention (MHA) 在推理时需缓存完整的 K 和 V 矩阵,导致:

  • KV 缓存爆炸:序列长度 × 层数 × 头数 × 头维度 × 2
  • 显存瓶颈:限制长上下文、高并发推理
  • 计算冗余:K/V 矩阵存在大量低秩结构,可压缩

现有方案对比

方案 KV 缓存大小 表达能力 推理速度 代表模型
MHA 2 ⋅ L ⋅ H ⋅ d h 2 \cdot L \cdot H \cdot d_h 2LHdh Llama, GPT-3
MQA 2 ⋅ L ⋅ d h 2 \cdot L \cdot d_h 2Ldh Falcon, Phi-2
GQA 2 ⋅ L ⋅ G ⋅ d h 2 \cdot L \cdot G \cdot d_h 2LGdh Llama2-70B, Mixtral
MLA L ⋅ d c L \cdot d_c Ldc 高(近似 MHA) 快(近似 MQA) DeepSeek-V2

MLA 核心优势

  • KV 缓存压缩比 d c ≪ H ⋅ d h d_c \ll H \cdot d_h dcHdh → 缓存大小 ≈ MQA
  • 保持多头表达力:通过低秩重建 + RoPE 解耦,不损失性能
  • 矩阵吸收优化:推理时无需显式重建 K/V,可吸收到 Q/O 矩阵

🧮 数学形式化(修正 + 增强版)

设输入 token 嵌入: h t ∈ R d h_t \in \mathbb{R}^d htRd

1. 联合压缩 Key & Value(核心创新)

引入低秩潜在向量 c t K V ∈ R d c c_t^{KV} \in \mathbb{R}^{d_c} ctKVRdc,其中 d c ≪ d h ⋅ H d_c \ll d_h \cdot H dcdhH

c t K V = W D K V h t (下投影) c_t^{KV} = W^{DKV} h_t \quad \text{(下投影)} ctKV=WDKVht(下投影)
k t C = W U K c t K V , v t C = W U V c t K V (上投影重建) k_t^C = W^{UK} c_t^{KV}, \quad v_t^C = W^{UV} c_t^{KV} \quad \text{(上投影重建)} ktC=WUKctKV,vtC=WUVctKV(上投影重建)

💡 关键设计:K 和 V 共享同一个潜在向量 c t K V c_t^{KV} ctKV,实现联合压缩。


2. 解耦位置编码(RoPE)

为保留位置信息,引入独立路径生成带 RoPE 的 K:

k t R = RoPE ( W K R h t ) k_t^R = \text{RoPE}(W^{KR} h_t) ktR=RoPE(WKRht)

最终 K 为拼接形式:
k t = [ k t C ; k t R ] ∈ R ( d h + d h R ) × H k_t = [k_t^C; k_t^R] \in \mathbb{R}^{(d_h + d_h^R) \times H} kt=[ktC;ktR]R(dh+dhR)×H

⚠️ 注意:在 DeepSeek-V2 中, k t R k_t^R ktR 使用 单头设计(MQA-style),即所有头共享同一个 RoPE-K,进一步节省缓存。


3. 查询 Q 的低秩压缩(可选,用于训练内存优化)

c t Q = W D Q h t c_t^Q = W^{DQ} h_t ctQ=WDQht
q t C = W U Q c t Q q_t^C = W^{UQ} c_t^Q qtC=WUQctQ
q t R = RoPE ( W Q R c t Q ) q_t^R = \text{RoPE}(W^{QR} c_t^Q) qtR=RoPE(WQRctQ)
q t = [ q t C ; q t R ] q_t = [q_t^C; q_t^R] qt=[qtC;qtR]


4. 注意力计算

对每个头 i i i

score t , j , i = q t , i T k j , i d h + d h R \text{score}_{t,j,i} = \frac{q_{t,i}^T k_{j,i}}{\sqrt{d_h + d_h^R}} scoret,j,i=dh+dhR qt,iTkj,i
α t , j , i = softmax j ( score t , j , i ) \alpha_{t,j,i} = \text{softmax}_j(\text{score}_{t,j,i}) αt,j,i=softmaxj(scoret,j,i)
o t , i = ∑ j α t , j , i ⋅ v j , i C o_{t,i} = \sum_j \alpha_{t,j,i} \cdot v_{j,i}^C ot,i=jαt,j,ivj,iC

最终输出:

u t = W O ⋅ Concat ( o t , 1 , … , o t , H ) u_t = W^O \cdot \text{Concat}(o_{t,1}, \dots, o_{t,H}) ut=WOConcat(ot,1,,ot,H)


💾 KV 缓存优化机制(重点增强)

推理阶段只需缓存:

潜在向量 c j K V ∈ R d c c_j^{KV} \in \mathbb{R}^{d_c} cjKVRdc,而非完整的 k j , v j k_j, v_j kj,vj

→ 缓存大小从 2 ⋅ L ⋅ H ⋅ d h 2 \cdot L \cdot H \cdot d_h 2LHdh 降至 L ⋅ d c L \cdot d_c Ldc

矩阵吸收技巧(无需显式重建 K/V):

在推理时,可预先合并矩阵:

  • W U K W^{UK} WUK 吸收到 W Q W^Q WQ W new Q = W Q ⋅ W U K W^Q_{\text{new}} = W^Q \cdot W^{UK} WnewQ=WQWUK
  • W U V W^{UV} WUV 吸收到 W O W^O WO W new O = W O ⋅ W U V W^O_{\text{new}} = W^O \cdot W^{UV} WnewO=WOWUV

实际推理中,从不显式计算 k t C , v t C k_t^C, v_t^C ktC,vtC,直接用 c t K V c_t^{KV} ctKV 参与点积


🖼️ 结构图

在这里插入图片描述

🔄 训练时:显式计算所有路径
🚀 推理时:缓存 c t K V c_t^{KV} ctKV,吸收矩阵,不重建 K/V


🧪 完整可运行代码(增强版:支持 KV 缓存、注释优化)

import torch
import torch.nn as nn
import math


class RotaryEmbedding(nn.Module):
    def __init__(self, d_model: int, num_heads: int, base: int = 10000, max_len: int = 512):
        """
        旋转位置编码(RoPE)模块

        参数:
            d_model (int): 输入特征维度
            num_heads (int): 注意力头数
            base (int): 频率基底,控制波长范围,默认10000
            max_len (int): 预生成位置编码的最大长度,默认512
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被num_heads整除"

        self.head_dim = d_model // num_heads  # 每个注意力头的维度
        self.d_model = d_model
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len

        # 预计算位置编码(训练时固定不更新)
        self.register_buffer("cos_pos_cache", self._compute_cos_emb())
        self.register_buffer("sin_pos_cache", self._compute_sin_emb())

    def _compute_angle_rates(self):
        """计算角度变化率 theta_i = 1/(base^(2i/d))"""
        # 示例:当head_dim=4时,i的取值为[0, 1, 2]
        i = torch.arange(0, self.head_dim, 2, dtype=torch.float)
        return 1.0 / (self.base ** (i / self.head_dim))

    def _compute_cos_emb(self):
        """ 计算余弦分量位置编码 """
        theta = self._compute_angle_rates()
        positions = torch.arange(self.max_len).unsqueeze(1)  # [max_len, 1]
        pos_angle = positions * theta  # [max_len, head_dim//2]
        return torch.cos(pos_angle).repeat_interleave(2, dim=-1)  # 维度扩展 [max_len, head_dim]

    def _compute_sin_emb(self):
        """ 计算正弦分量位置编码 """
        theta = self._compute_angle_rates()
        positions = torch.arange(self.max_len).unsqueeze(1)  # [max_len, 1]
        pos_angle = positions * theta  # [max_len, head_dim//2]
        return torch.sin(pos_angle).repeat_interleave(2, dim=-1)  # 维度扩展 [max_len, head_dim]

    def _rotate_half(self, x):
        """ 执行旋转操作:将后一半维度与前一半交换,并取反 """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q):
        """ 应用旋转位置编码到查询向量
        参数:
            q (Tensor): 输入查询向量,形状为 [batch_size, seq_len, d_model]
        返回:
            rotated_q (Tensor): 旋转后的查询向量,形状保持 [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, _ = q.shape

        # 获取当前序列长度的位置编码
        cos_pos = self.cos_pos_cache[:seq_len]  # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len]  # [seq_len, head_dim]

        # 调整查询向量形状以匹配查询向量 [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 扩展位置编码维度以匹配查询向量 [batch_size, num_heads, seq_len, head_dim]
        # 使用unsqueeze自动广播替代显式repeat操作,更高效
        cos_pos = cos_pos.unsqueeze(0).unsqueeze(1)  # [1, 1, seq_len, head_dim]
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(1)  # [1, 1, seq_len, head_dim]

        # 执行旋转操作(高效实现)
        rotated_q = q * cos_pos + self._rotate_half(q) * sin_pos

        # 恢复原始形状 [batch_size, seq_len, d_model]
        return rotated_q.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)


class MLA(nn.Module):
    def __init__(self, d_model=512, down_dim=128, up_dim=256, num_heads=8,
                 rope_head_dim=26, dropout_prob=0.1):
        """
        Args:
            d_model (int): 输入特征维度
            down_dim (int): 低秩降维后的维度
            up_dim (int): 升维后的维度 (需能被num_heads整除)
            num_heads (int): 注意力头数
            rope_head_dim (int): RoPE (旋转位置编码) 每个头的维度
            dropout_prob (float): Dropout概率,默认0.1
        """
        super(MLA, self).__init__()

        # 参数初始化
        self.d_model = d_model
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # 标准注意力头的维度
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads  # 位向量的每个头维度

        # 低秩投影层 (用于Key/Value的联合降维)
        self.down_proj_kv = nn.Linear(d_model, down_dim)  # W^(DKV): 联合降维K/V
        self.up_proj_k = nn.Linear(down_dim, up_dim)  # W^(UK): 升维K
        self.up_proj_v = nn.Linear(down_dim, up_dim)  # W^(UV): 升维V

        # 查询向量独立降维
        self.down_proj_q = nn.Linear(d_model, down_dim)  # W^(DQ): Q的降维
        self.up_proj_q = nn.Linear(down_dim, up_dim)  # W^(UQ): Q的升维

        # 解耦的RoPE投影层 (独立处理Q/K)
        self.proj_qr = nn.Linear(d_model, rope_head_dim * num_heads)  # 生成多头RoPE的Q
        self.proj_kr = nn.Linear(d_model, rope_head_dim * 1)  # 生成单头RoPE的K (MQA设计)
        # RoPE位置编码实例(Q使用多头,K使用单头)
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)  # Q使用多查询注意力 (MQA)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)  # K使用单查询注意力 (MQA)

        # 注意力计算后的处理层
        self.dropout = nn.Dropout(dropout_prob)  # 注意力权重Dropout
        self.fc = nn.Linear(num_heads * self.v_head_dim, d_model)  # 合并多头输出
        self.res_dropout = nn.Dropout(dropout_prob)  # 残差连接后的Dropout

    def forward(self, h, mask=None):
        """
        Args:
            h (Tensor): 输入张量,形状为 [batch_size, seq_len, d_model]
            mask (Tensor): 注意力掩码,形状为 [batch_size, seq_len, seq_len]

        Return:
            output (Tensor): 输出张量,形状同输入
        """
        bs, seq_len, _ = h.size()

        # --- 阶段1: 低秩变换 ---
        # 对K/V进行联合降维+升维
        c_t_kv = self.down_proj_kv(h)  # [bs, seq, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [bs, seq, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)  # [bs, seq, up_dim]

        # 对Q独立降维+升维
        c_t_q = self.down_proj_q(h)  # [bs, seq, down_dim]
        q_t_c = self.up_proj_q(c_t_q)  # [bs, seq, up_dim]

        # --- 阶段2: 解耦的RoPE处理 ---
        # 生成带RoPE的Q/K(维度扩展为[bs, num_heads, seq_len, rope_head_dim])
        q_t_r = self.rope_q(self.proj_qr(h))  # Q的RoPE(多头)
        k_t_r = self.rope_k(self.proj_kr(h))  # K的RoPE(单头,MQA设计)

        # --- 阶段3: 张量拼接与注意力计算 ---
        # 处理Q的低秩部分:调整形状以匹配多头
        q_t_c = q_t_c.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # [bs, heads, seq, head_dim]

        # 处理Q的RoPE部分:调整形状以匹配多头
        q_t_r = q_t_r.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # [bs, heads, seq, rope_head_dim]

        q = torch.cat([q_t_c, q_t_r], dim=-1)  # 拼接低秩Q和RoPE Q [bs, heads, seq, head_dim + rope_head_dim]

        # 处理K的低秩部分:调整形状以匹配多头
        k_t_c = k_t_c.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        # [bs, heads, seq, head_dim]

        # 处理K的RoPE部分:调整形状以匹配多头
        k_t_r = k_t_r.reshape(bs, seq_len, 1, -1).transpose(1, 2)  # 先转为[bs, 1, seq, rope_head_dim]
        k_t_r = k_t_r.repeat(1, self.num_heads, 1, 1)  # 再复制到多头[bs, heads, seq, rope_head_dim]
        # [bs, heads, seq, rope_head_dim]

        k = torch.cat([k_t_c, k_t_r], dim=-1)  # 拼接低秩K和RoPE K
        # [bs, heads, seq, head_dim + rope_head_dim]

        # 计算缩放点积注意力
        scores = torch.matmul(q, k.transpose(-1, -2))  # [bs, heads, seq, seq]
        if mask is not None:  # 应用注意力掩码
            # 调整mask的形状以匹配注意力分数
            mask = mask.unsqueeze(1)  # [bs, 1, seq, seq]
            scores = scores.masked_fill(mask == 0, -1e9)

        # 缩放(考虑拼接后的总维度)
        scale = math.sqrt(self.head_dim + self.rope_head_dim)  # 更新缩放因子
        scores = torch.softmax(scores / scale, dim=-1)
        scores = self.dropout(scores)

        # 计算加权值向量
        v_t_c = v_t_c.reshape(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        # [bs, heads, seq, v_head_dim]
        output = torch.matmul(scores, v_t_c)  # [bs, heads, seq, v_dim]

        # 合并多头输出并通过全连接层
        output = output.transpose(1, 2).reshape(bs, seq_len, -1)  # [bs, seq, d_model]
        output = self.fc(output)  # [bs, seq, d_model]
        output = self.res_dropout(output)  # 残差连接前应用Dropout
        return output


if __name__ == '__main__':
    # 假设我们有一些输入参数
    batch_size = 4
    seq_len = 256
    d_model = 512

    # 创建一个随机输入张量,模拟一批序列数据
    input_tensor = torch.randn(batch_size, seq_len, d_model)

    # 初始化 MLA 模块
    mla_layer = MLA(d_model=d_model, down_dim=128, up_dim=256, num_heads=8,
                    rope_head_dim=26, dropout_prob=0.1)

    # 创建一个可选的注意力掩码(例如用于屏蔽填充位置)
    # 这里我们创建一个全1的掩码,表示所有位置都可见
    mask = torch.ones(batch_size, seq_len, seq_len)

    # 执行前向传播
    output = mla_layer(input_tensor, mask=mask)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")

    # 验证输出张量形状是否与输入一致
    assert input_tensor.shape == output.shape, "输入和输出张量形状不匹配"

在这里插入图片描述

📊 性能对比与适用场景

MLA vs MHA vs MQA vs GQA

指标 MHA MQA GQA (G=8) MLA
KV 缓存大小 2×L×H×d 2×L×d 2×L×G×d L×d_c
计算复杂度 O(L²Hd) O(L²d) O(L²Gd) O(L²Hd)(训练)
O(L²d_c)(推理)
表达能力
推理速度
适用场景 短文本、训练 长文本、低成本 平衡场景 长文本+高性能

MLA 最佳适用场景

  • 长上下文推理(32K+ tokens)
  • 高并发服务(KV 缓存小 → 支持更多并发)
  • 资源受限但要求高性能(如边缘设备、手机端)

⚠️ 局限性与注意事项

  1. 训练复杂度未降低:MLA 主要优化推理,训练时仍需计算完整路径。
  2. 超参敏感 d c d_c dc 需仔细调优,过小损失性能,过大失去压缩意义。
  3. 矩阵吸收需工程支持:推理框架需支持矩阵融合(如 vLLM、TensorRT-LLM)。
  4. 位置编码设计关键:RoPE-K 的单头设计是性能保障,不可随意改为多头。

📚 原始论文与引用

MLA 首次提出于
《DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model》
DeepSeek AI, 2024
arXiv:2405.04434

🏁 总结

MLA 是目前最先进的注意力机制之一,它:

✅ 在推理效率上媲美 MQA
✅ 在模型性能上接近 MHA
✅ 通过低秩联合压缩 + RoPE 解耦 + 矩阵吸收三重优化实现突破
✅ 是构建长上下文、高并发、低成本 LLM 服务的理想选择

💡 未来方向

  • 动态压缩率(根据 token 重要性调整 d c d_c dc
  • 与 MoE 结合(专家级 MLA)
  • 硬件友好设计(专用 MLA 加速器)
Logo

讨论HarmonyOS开发技术,专注于API与组件、DevEco Studio、测试、元服务和应用上架分发等。

更多推荐