本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

状态空间模型与 Mamba

状态空间模型与 Mamba

更新于 2026-04-13

引言

标准 Attention 的计算复杂度为 O(L2)O(L^2)LL = 序列长度),推理时 KV cache 随序列长度线性增长。100K token 的上下文意味着 Attention 矩阵有 101010^{10} 个元素——显存和计算成本成为不可忽视的瓶颈。

RNN 看似是一个解决方案:它的推理复杂度为 O(1)O(1),每步只需一个固定大小的隐状态。但传统 RNN 有两个致命弱点:(1) 梯度消失/爆炸导致无法学习长距离依赖;(2) 逐步递推的本质使得训练无法并行化。

有没有一种方法,同时解决两个问题——训练可并行(像 Attention)+ 推理 O(1)O(1)(像 RNN)+ 理论上可学习任意长距离依赖?状态空间模型 (State Space Model, SSM) 正是这样一种方案。SSM 源自控制论和信号处理,核心思想是用一个 固定大小的状态向量 压缩序列的全部历史信息,每处理一个新 token 只需更新状态而非回看全部历史。SSM 的关键演进路线:HiPPO (2020) → S4 (2021) → H3 (2022) → Mamba (2023) → Mamba-2 (2024)。


1. 连续状态空间模型

SSM 的数学基础是一个连续时间的线性动态系统:

x˙(t)=Ax(t)+Bu(t),y(t)=Cx(t)+Du(t)\dot{x}(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)

其中 u(t)u(t) 是输入信号,x(t)RNx(t) \in \mathbb{R}^N 是隐状态,y(t)y(t) 是输出。矩阵 ARN×NA \in \mathbb{R}^{N \times N} 控制状态的演化规律,BRN×1B \in \mathbb{R}^{N \times 1} 控制输入如何写入状态,CR1×NC \in \mathbb{R}^{1 \times N} 控制如何从状态读出输出。

其中 DD 项是直接前馈连接(skip connection),在 S4 和 Mamba 的实际实现中通常设为 0 或恒等映射,不影响 SSM 的核心动态,后文省略。

直觉:状态 xx 是”压缩的历史摘要”。与 Attention 不同——Attention 每步都回看完整的 token 序列——SSM 只维护一个固定大小的状态,所有历史信息都被压缩到这 NN 个维度中。这就像一个”带记忆的滤波器”:输入信号经过系统,被状态记住、混合、然后输出。

上述是单输入单输出(SISO)的最简形式。实际实现中,SSM 在模型的每个特征维度 DD 上独立运行(类似 depth-wise convolution)。

信号处理类比:SSM 就像一个”带记忆的 IIR 滤波器”。输入信号 u(t)u(t) 通过系统,状态 x(t)x(t) 相当于滤波器的内部寄存器,AA 控制寄存器间的反馈(极点位置),BB 控制输入到寄存器的路径,CC 控制从寄存器到输出的路径。

关键问题:AA 矩阵的初始化决定了”记忆的数学结构”。随机初始化的 AA 会导致信息指数衰减,使模型无法学习长距离依赖。这正是下一节 HiPPO 要解决的问题。

1. 初始状态 x₀ = 0
SSM 初始化:空状态向量x₀ (state)0.000.000.000.00∈ ℝ^4状态向量 x ∈ ℝᴺ 是 SSM 的"记忆"— 固定大小,不随序列增长N 通常为 16-64,远小于序列长度

1.5 HiPPO:SSM 记忆的数学基础

一个 NN 维状态向量如何”最优地”压缩一段连续信号的历史?随机初始化的 AA 矩阵会导致信息指数衰减——状态只记得最近几步,远处信息迅速丢失。

HiPPO(High-order Polynomial Projection Operators, Gu et al., NeurIPS 2020)给出了一个优雅的解法:用 正交多项式基(Legendre 多项式)来逼近历史信号。状态向量的第 nn 个分量 = 信号在第 nn 个 Legendre 多项式上的投影系数。这样 NN 维状态就是信号历史的 NN 阶多项式逼近,而非简单的指数衰减。

HiPPO-LegS 矩阵(HiPPO 论文 Section 3.2)具有特殊的下三角 + 对角结构。这个特定的 AA 矩阵使得状态更新恰好等价于”在线计算 Legendre 投影系数”——每处理一个新 token,状态向量自动维护对整段历史信号的最优多项式逼近。

1. 随机 A 的遗忘问题
随机初始化 A → 指数衰减输入信号(6 个 token)w₁w₂w₃w₄w₅w₆状态向量(N=4)dim₀dim₁dim₂dim₃随机 A → 旧信息指数衰减 → 只记得最近几步

为什么关键:S4 论文实验证明,使用 HiPPO 初始化的 S4 在 Path-X(16384 步序列分类)任务上达到 SoTA,而随机初始化完全失败。HiPPO 为 SSM 提供了理论上可学习任意长距离依赖的基础。

Mamba 的简化:Mamba 不使用完整的 HiPPO 矩阵,而用更简单的 S4D-Real 初始化:A=diag(1,2,...,N)A = -\text{diag}(1, 2, ..., N)。代码实现中,A_log = log([1,...,N]) 存储为参数,计算时 A = -exp(A_log)。负号确保状态衰减(稳定性),对数存储确保数值稳定。Mamba 的选择性机制(input-dependent Δ\Delta, BB, CC)弥补了简化初始化的表达力损失。


2. 离散化:从连续到序列

为什么需要离散化? 语言模型处理离散 token 序列,但 SSM 的数学是连续时间的。类比:连续信号 → 数字采样,步长 Δ\Delta 就是采样间隔。

x˙(t)=Ax(t)+Bu(t)\dot{x}(t) = Ax(t) + Bu(t) 出发,在 [kΔ,(k+1)Δ][k\Delta, (k+1)\Delta] 区间积分,假设 u(t)u(t) 在区间内恒定(Zero-Order Hold 假设),得到:

Aˉ=eΔA,Bˉ=A1(eΔAI)B\bar{A} = e^{\Delta A}, \quad \bar{B} = A^{-1}(e^{\Delta A} - I) \cdot B

离散化后的递推公式:

xk=Aˉxk1+Bˉuk,yk=Cxkx_k = \bar{A} x_{k-1} + \bar{B} u_k, \quad y_k = C x_k

最简单的离散化是一阶 Euler 近似:AˉI+ΔA\bar{A} \approx I + \Delta A, BˉΔB\bar{B} \approx \Delta B。直觉清晰但精度低。Mamba 和 S4 实际使用 ZOH(精度更高)。不同离散化方法产生不同的 Aˉ\bar{A}, Bˉ\bar{B},但都将同一个连续系统映射到离散递推。

步长 Δ\Delta 在 LTI(线性时不变)语境下控制系统的”时间分辨率”:类比音频采样率——44.1kHz(小 Δ\Delta,高分辨率)vs 8kHz(大 Δ\Delta,低分辨率)。注意:到第 4 节 Mamba 的选择性 Δ\Delta 时,含义不同——大 Δ\Delta 意味着 “reset & focus”,小 Δ\Delta 意味着 “persist & ignore”(详见第 4 节)。


3. Recurrence 与 Convolution 的对偶性

离散 SSM 有一个优雅的性质:同一个模型可以用两种完全不同的方式计算。

Recurrence 模式:逐步递推 xk=Aˉxk1+Bˉukx_k = \bar{A}x_{k-1} + \bar{B}u_k,每步 O(1)O(1) 计算。适合 推理——每个新 token 只需一次状态更新。

Convolution 模式:将递推展开可以看到卷积结构:

y0=CBˉu0,y1=CAˉBˉu0+CBˉu1,yk=j=0kCAˉkjBˉujy_0 = C\bar{B}u_0, \quad y_1 = C\bar{A}\bar{B}u_0 + C\bar{B}u_1, \quad y_k = \sum_{j=0}^{k} C\bar{A}^{k-j}\bar{B} \cdot u_j

这正是卷积 y=Kˉuy = \bar{K} * u,卷积核 Kˉi=CAˉiBˉ\bar{K}_i = C\bar{A}^i\bar{B}。用 FFT 加速(时域卷积 = 频域乘法):y=IFFT(FFT(Kˉ)FFT(u))y = \text{IFFT}(\text{FFT}(\bar{K}) \cdot \text{FFT}(u)),整个序列在 O(LlogL)O(L \log L) 时间内并行计算(LL = 序列长度)。适合 训练——所有 token 同时处理。

Recurrence 模式(推理)Recurrence (逐步)Convolution (并行)输入 uu1u2u3u4u5u6状态 xx1B̄ux2ĀB̄ux3ĀB̄ux4ĀB̄ux5ĀB̄ux6ĀB̄u输出 yy1y2y3y4y5y6O(1) per step · O(N) total · Sequential适合推理:每个新 token 只需一次状态更新

这个对偶性是 S4 (Structured State Spaces for Sequence Modeling, Gu et al., ICLR 2022) 的核心贡献:训练时用 convolution 充分利用 GPU 并行性,推理时切换为 recurrence 实现 O(1)O(1) 增量推理。

S4 的技术关键:计算卷积核 Kˉi=CAˉiBˉ\bar{K}_i = C\bar{A}^i\bar{B} 需要 Aˉ\bar{A} 的高次幂——对一般矩阵代价很高。S4 通过 NPLR/DPLR 参数化(将 AA 分解为 Diagonal + Low-Rank),利用 Cauchy 核公式高效完成。简化版 S4D 直接对角化 AA,则 Aˉi\bar{A}^i 就是对角元素的 ii 次幂。

S4 的标志性成果:Long Range Arena 全部任务 SoTA,包括此前所有方法都失败的 Path-X(16384 步序列分类);Sequential CIFAR-10 达到 91%(无数据增强);生成速度比同参数 Transformer 快 60×。


3.5 从 S4 到 Mamba:为什么需要选择性?

S4 虽然在长序列任务上表现出色,但有一个根本局限——线性时不变 (LTI):所有参数(AA, BB, CC, Δ\Delta)固定不变。这意味着 SSM 对 “the” 和 “cat” 用完全相同的方式处理。从卷积视角看:固定卷积核 = 固定滤波器,无法根据内容做选择。

Selective Copying 任务(Mamba 论文 Section 4.1)清晰暴露了这个问题:任务要求在噪声 token 中选出有色 token 并按顺序复制,且间距不固定。S4(LTI)准确率仅 18.3%,而 Mamba(选择性)达到 99.8%。原因:静态卷积核无法处理不定间距的选择性复制。

H3(Hungry Hungry Hippos, Fu et al., 2022)尝试了一种过渡方案:在 SSM 层外加一些 Attention 层来补足检索能力。125M 的 H3-Attention hybrid 模型超过了纯 Transformer。但这引出了一个更根本的问题:与其在外面加 Attention 来补 SSM,不如让 SSM 自己具备选择性。这个思路直接催生了 Mamba。


4. Mamba 的选择性机制

Mamba (Gu & Dao, 2023) 的核心创新是 选择性 (Selectivity):让关键参数依赖于当前输入,从根本上解决第 3.5 节所述的 LTI 局限。

  • 选择性 BBBk=Linear(uk)B_k = \text{Linear}(u_k) — 控制”写入什么到状态”
  • 选择性 CCCk=Linear(uk)C_k = \text{Linear}(u_k) — 控制”从状态读出什么”
  • 选择性 Δ\DeltaΔk=softplus(Linear(uk))\Delta_k = \text{softplus}(\text{Linear}(u_k)) — 控制”记忆的更新强度”

Δ\Delta 的作用机制:由于 AA 初始化为负值,Aˉ=eΔA\bar{A} = e^{\Delta A} 的行为取决于 Δ\Delta 的大小。Δ\DeltaΔA\Delta A 是大负数 → Aˉ0\bar{A} \to 0,旧状态被大幅衰减,同时 Bˉ\bar{B} 增大,当前输入被强烈写入——效果是 “reset & focus”(清空旧态,聚焦当前 token)Δ\DeltaAˉI\bar{A} \approx I,旧状态几乎完整保留,同时 Bˉ0\bar{B} \to 0,当前输入几乎被忽略——效果是 “persist & ignore”(保留旧态,忽略当前 token)

对于内容词(“cat”、“mat”),模型输出大 Δ\Delta,将当前 token 强烈写入状态;对于功能词(“the”、“on”),模型输出小 Δ\Delta,让状态几乎不变。

RNN 门控等价:Mamba 论文 Theorem 1 证明,当 N=1N=1, A=1A=-1, B=1B=1 时,选择性 SSM 精确退化为门控 RNN:gt=σ(Linear(xt))g_t = \sigma(\text{Linear}(x_t)), ht=(1gt)ht1+gtxth_t = (1-g_t)h_{t-1} + g_t x_t(1gt)(1-g_t) 控制遗忘,gtg_t 控制写入——两者耦合为同一标量(不同于 LSTM 的独立遗忘门和输入门)。换言之,SSM 的离散化步长 Δ\Delta 是 RNN 启发式门控机制的理论基础

Mamba 选择性机制:Δ 控制状态更新点击 token 查看其对状态的影响ThecatsatonthematΔ 值0.150.850.720.120.100.88Ā=0.74Ā=0.18Ā=0.24Ā=0.79Ā=0.82Ā=0.17大 Δ = reset & focus(写入当前)小 Δ = persist & ignore(保留旧态)累积状态 (全部 tokens)0.213dim 0-0.106dim 10.295dim 20.164dim 3选择性 Δ 让 Mamba 自适应地关注重要 token,忽略噪声 — 类似 Attention 的"软选择"

Parallel Scan:选择性 SSM 的并行训练

选择性的代价:参数依赖输入后,SSM 不再是时不变系统 (LTI),卷积核 Kˉ\bar{K} 随输入变化,无法再用 FFT 加速。递推 xk=Aˉkxk1+Bˉkukx_k = \bar{A}_k x_{k-1} + \bar{B}_k u_k 中参数随 kk 变化——如何并行化?

关键观察:递推 (xk,ak)=(akxk1+bk)(x_k, a_k) = (a_k \cdot x_{k-1} + b_k) 是一个结合律运算。两个运算 (a1,b1)(a_1, b_1)(a2,b2)(a_2, b_2) 可以组合为 (a2a1,a2b1+b2)(a_2 a_1, a_2 b_1 + b_2)。结合律意味着可以用 并行前缀和(parallel prefix sum)算法——类似 GPU 上的 parallel reduction:

  • Work: O(L)O(L)(总计算量与顺序执行相同)
  • Span: O(logL)O(\log L)(关键路径长度,决定并行时间)
1. 顺序递推 — O(L) 步
选择性 SSM 必须逐步计算x1x2x3x4x5x6x7x8xₖ = Āₖ · xₖ₋₁ + B̄ₖ · uₖ1234567必须等前一步完成才能计算下一步 → 7 步完成 8 个元素

硬件感知优化(Mamba 论文 Section 3.3.2):朴素实现需要在 HBM 中存储 (Aˉ,Bˉ)(\bar{A}, \bar{B}) 张量,尺寸为 (B,L,D,N)(B, L, D, N)——显存爆炸。Mamba 的做法:将 SSM 参数 (Δ,A,B,C)(\Delta, A, B, C) 直接从 HBM 加载到 SRAM,在 SRAM 中完成离散化 + scan,只将最终输出 (B,L,D)(B, L, D) 写回 HBM。反向传播时不存中间状态,而是重新计算(类似 gradient checkpointing)。结果:与使用 FlashAttention 的优化 Transformer 具有相同的显存需求。

Mamba Block 架构

Mamba 将 Transformer 中 Attention 和 MLP 的功能融合到一个更简洁的 block 中:

Mamba Block 架构Input x(B, L, D)ResidualLinear ↑D → ED (expand)Conv1dkernel=4SiLU (σ)activationSSM (Selective)Δ, B, C = f(input)SiLU (gate)(B, L, ED)×Linear ↓ED → D (contract)+Output (B, L, D)无 Attention无 MLP比 Transformerblock 更简洁

架构要点:

  • Conv1d 的作用:提供局部上下文(kernel size = 4),让 SSM 在做选择性决策时能”看到”邻近 token。没有 Conv1d,SSM 的 BB, CC, Δ\Delta 只基于单个 token embedding 做决策。
  • Gating 机制:右分支的 SiLU gate 类似 GLU (Gated Linear Unit)。输出 = SSM_output × σ\sigma(gate_input),控制信息流,防止过度激活。
  • 维度变化:expand factor E=2E=2,Linear↑ 将 D2DD \to 2D,Linear↓ 将 2DD2D \to D
  • 与 Transformer 对比:Transformer block = LayerNorm → Multi-Head Attention → Residual → LayerNorm → MLP → Residual;Mamba block = LayerNorm → Linear↑ → (Conv1d → SiLU → SSM) × Gate → Linear↓ → Residual。

5. Mamba-2: State Space Duality

Mamba-2 (Dao & Gu, ICML 2024) 建立了 SSM 和 Attention 之间的深层数学联系——State Space Duality (SSD)

核心发现:将 SSM 递推展开为矩阵形式 Y=MUY = M \cdot U,矩阵 MM 是一个 semiseparable matrix(半可分矩阵)。矩阵 MM 的下三角部分的任意子矩阵 rank N\leq N(状态维度)。直觉:因为每步只有 NN 维状态”传递”信息,矩阵的”信息带宽”被限制为 NN。对比 Attention:QKTQK^T 的 rank = head dim,远高于 SSM 的 NN,因此更 expressive 但也更慢。

1. SSM → Semiseparable Matrix
SSM 递推展开为矩阵乘法 Y = M · UM (semiseparable)0.60.00.00.00.00.00.50.90.00.00.00.00.40.70.30.00.00.00.40.60.30.60.00.00.30.50.20.50.90.00.30.40.20.40.70.3·U (input)u1u2u3u4u5u6Semiseparable 结构:M[i,j] = C·Ā^(i-j)·B̄ (j ≤ i)M[i,j] = 0 (j > i)因果 + 指数衰减结构SSM 递推 xₖ = Āxₖ₋₁ + B̄uₖ 展开后等价于一个结构化矩阵乘法

Chunk-wise 算法:将长度 LL 序列分成 L/QL/Q 个 chunk,每个大小 QQ。Chunk 内用矩阵乘法(Q×QQ \times Q 矩阵,利用 tensor core),chunk 间用 SSM scan(只传递 NN 维状态)。总计算量 O(LN2)O(LN^2) FLOPs, O(LN)O(LN) 显存——类似 Flash Attention 的分块策略。

Multi-head SSM:Mamba-1 的 head dim P=1P=1(每个特征维度独立 SSM)。Mamba-2 引入 head dim P=64P=64128128(多个维度共享 AA,类似 multi-head attention)。更大的 head dim 让矩阵乘法更高效(GPU 偏好大矩阵),同时增强模型表达能力。

结构化 Masked Attention 视角:SSD 的 dual form 中,Λij=t=j+1iat\Lambda_{ij} = \prod_{t=j+1}^{i} a_t 是 input-dependent 标量的连乘。这可以看作用 data-dependent 的位置 mask 替代了 Transformer 的启发式位置编码。加速的主要来源是 semiseparable 结构允许 sub-quadratic 的 chunk-wise 算法,而非仅仅省去 softmax。

性能:Mamba-2 核心层速度比 Mamba-1 fused scan 快 2-8×;与 FlashAttention-2 的交叉点在序列长度 ~2K,序列长度 16K 时比 FlashAttention-2 快 ~6×。


6. 实战对比与基准测试

以下数据全部来自 Mamba 论文 (Gu & Dao, arXiv:2312.00752)。

语言模型质量

在 Pile 数据集上,Mamba 在各个参数规模上都超越了同规模的 Pythia(Transformer baseline):

模型参数量Perplexity
Pythia-1.4B1.4B7.51
RWKV-1.5B1.5B7.70
Mamba-1.4B1.4B6.80
Pythia-2.8B2.8B6.73
Mamba-2.8B2.8B6.22

下游任务平均准确率(Mamba 论文 Table 1):Mamba-1.4B 达到 59.7%,超越参数量两倍的 Pythia-2.8B(59.1%);Mamba-2.8B 达到 63.3%,超越参数量 2.5 倍的 Pythia-6.9B(61.7%)。即 Mamba 大约以一半参数量匹配同质量 Transformer

推理吞吐量

Mamba 比同参数 Transformer 快 (A100 80GB,prompt 2048 + gen 128)。

Mamba vs Transformer 性能对比Pile 数据集 Perplexity(越低越好)370M8.148.551.4B6.807.512.8B6.226.73MambaPythia (Transformer)Mamba-2.8B (63.3%) > Pythia-6.9B (61.7%)推理吞吐量强项 vs 弱项推理吞吐量(A100 80GB)Transformer1×Mamba5×Mamba 比同参数 Transformer 快 5×(prompt 2048 + gen 128)

SSM 的弱项

SSM 的固定 NN 维状态本质上是有损压缩。当需要从长序列中精确召回某个特定 token 时(如”第 3000 个 token 是什么?”),NN 维向量无法无损存储 LNL \gg N 个 token 的所有信息。这不是 bug,是 feature——SSM 擅长的是摘要和模式识别,不是精确检索

具体弱项包括:

  • 多查询关联召回 (MQAR):Mamba-1 在此任务上挣扎(Mamba-2 论文 Figure 8)
  • In-context learning:纯 SSM 仍弱于 Attention,Mamba-2 论文建议 hybrid 架构
  • 长序列精确复制:固定状态维度限制了精确复制能力

总结

SSM 提供了一种根本不同于 Attention 的序列建模范式:

AttentionSSM (LTI, 如 S4)SSM (Selective, 如 Mamba)
训练复杂度O(L2)O(L^2)O(LlogL)O(L \log L) (FFT)O(L)O(L) (parallel scan)
推理缓存O(L)O(L) (KV cache)O(1)O(1) (固定状态)O(1)O(1) (固定状态)
历史访问精确(任意 token 对)压缩 + 固定模式压缩 + 自适应模式
内容感知完全(QKV 全依赖输入)无(参数固定)部分(B, C, Δ 依赖输入)
长序列 ICL中等(仍弱于 Attention)
核心优势精确检索并行训练 + 高效推理选择性 + 线性复杂度

SSM 的发展路线图清晰地展示了每一步解决的关键问题:HiPPO(2020)解决了”如何初始化 A 才能记住长历史”→ S4(2021)解决了”如何高效训练”(convolution mode)→ H3(2022)发现”SSM 需要选择性”→ Mamba(2023)解决了”如何让 SSM 选择性地处理输入”→ Mamba-2(2024)统一了 SSM 和 Attention 的数学框架。

这个根本局限引出了下一篇文章的主题:Hybrid 架构 如何结合 Attention 的精确检索和 SSM 的高效摘要,取长补短。

延伸阅读