本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

SSM / Mamba:矩阵对角化的胜利

SSM / Mamba:矩阵对角化的胜利

更新于 2026-04-22

这是整条 27 篇矩阵数学路径的最后一篇文章。

Art. 23 学习算子 中,我们预告了 Part 3 的三个汇聚方向:LoRA(压缩微调)、Efficient Attention(加速推理)、SSM/Mamba(高效架构设计)。前两个方向都是事后发现矩阵的低秩性然后加以利用。SSM/Mamba 走了一条不同的路——在架构设计阶段就主动施加对角结构约束,将 Art. 2 特征分解中”对角化简化一切”的洞察推到极致。

更深层地看,SSM/Mamba 将整条路径的三段弧线汇聚在一个点上:

  • Part 1 的工具(特征分解、对角化)提供了数学基础
  • Part 2 的概念(state-space 模型、矩阵指数、离散化,来自 Art. 17 Kalman)提供了建模框架
  • Part 3 的思路(把给定算子变成可学习参数)提供了设计哲学

这三条线在 SSM/Mamba 这一点交汇。让我们从头开始。

为什么需要 SSM?序列建模的困境

Transformer 的 self-attention 是一种强大的序列建模机制,但它有一个根本性的计算瓶颈:注意力矩阵的大小是 O(L2)O(L^2),其中 LL 是序列长度。对于长序列(音频波形 L105L \sim 10^5、基因组序列 L106L \sim 10^6),O(L2)O(L^2) 的计算和内存开销变得不可承受。

Art. 25 Efficient Attention 试图通过低秩近似来降低 attention 的复杂度。但另一条路线更加激进:完全抛弃 attention,回到递推模型的框架

递推模型(RNN)天然是 O(L)O(L) 的——每步只需要固定大小的状态更新。但经典 RNN 有梯度消失/爆炸问题,难以捕获长程依赖。问题在于:

如何设计一个递推模型,既保持 O(L)O(L) 的计算复杂度,又能有效建模长程依赖?

答案来自一个经典的数学领域:连续时间状态空间模型——正是 Art. 17 连续系统与 Kalman 中 Kalman 滤波使用的那套公式。但这次,矩阵不是物理系统给定的,而是从数据中学出来的

从 Kalman 到 SSM:同一公式,不同语境

State-Space 公式回顾

Art. 17 连续系统与 Kalman 中,我们建立了连续时间状态空间模型的标准形式:

x˙(t)=Ax(t)+Bu(t)\dot{\mathbf{x}}(t) = A\mathbf{x}(t) + B\mathbf{u}(t) y(t)=Cx(t)+Du(t)\mathbf{y}(t) = C\mathbf{x}(t) + D\mathbf{u}(t)

逐项回顾(沿用 Art. 17 的符号):

  • x(t)RN\mathbf{x}(t) \in \mathbb{R}^N状态向量——系统的内部记忆
  • u(t)R1\mathbf{u}(t) \in \mathbb{R}^1输入信号(在序列建模中,这是输入 token 的特征投影,为简化讨论取标量)
  • y(t)R1\mathbf{y}(t) \in \mathbb{R}^1输出信号
  • ARN×NA \in \mathbb{R}^{N \times N}状态矩阵——编码记忆如何自演化
  • BRN×1B \in \mathbb{R}^{N \times 1}输入矩阵——输入如何注入状态
  • CR1×NC \in \mathbb{R}^{1 \times N}输出矩阵——状态如何映射为输出
  • DR1×1D \in \mathbb{R}^{1 \times 1}直通矩阵——输入对输出的直接影响(通常设为 0,下文省略)

Kalman 与 SSM 的关键差异

Art. 17 中,这套公式描述的是物理系统(卫星轨道、电路状态),(A,B,C)(A, B, C) 由物理定律和传感器特性决定——它们是给定的(Part 2 的”给定算子”)。

在 SSM/Mamba 中,完全相同的公式被用于序列建模(语言、音频、DNA),(A,B,C)(A, B, C) 是从数据中学习的参数——它们是学习算子(Part 3 的核心转变)。

Kalman 滤波(Art. 17)SSM / Mamba(本文)
AA 的来源物理系统的动力学方程可学习参数(数据驱动)
B,CB, C 的来源传感器模型、系统结构可学习参数
目标从噪声观测中推断隐状态(最优估计)从输入序列预测输出(序列建模)
核心挑战处理噪声和不确定性长程依赖 + 计算效率
AA 的结构任意(由物理决定)约束为对角或 DPLR(对角+低秩)结构

关键的一句话总结:

Kalman 是给定系统的最优滤波,SSM 是学出来的序列建模。

两者用的是同一套数学——state-space 公式、矩阵指数、离散化——但参数的来源和目标完全不同。

HiPPO:连续时间记忆的数学框架

直接把 AA 初始化为随机矩阵然后训练,效果并不好——模型难以学会长程依赖。Gu et al. (2020) 在 HiPPO(High-order Polynomial Projection Operators)框架中给出了一个关键洞察:

AA 矩阵不应该随机初始化——它应该编码一种”最优记忆”策略。

HiPPO 的核心思想

想象你在看一段连续信号 u(t)u(t)。在任意时刻 tt,你想用有限的 NN 维状态向量 x(t)\mathbf{x}(t)压缩地记住 uu[0,t][0, t] 上的全部历史。

HiPPO 的方法是:用一组正交多项式基函数(如 Legendre 多项式)来逼近信号的历史。状态向量 x(t)\mathbf{x}(t) 的第 nn 个分量 xn(t)x_n(t) 就是信号 uu 在第 nn 个基函数上的投影系数:

xn(t)=0tu(s)Pn(t)(s)dsx_n(t) = \int_0^t u(s) \cdot P_n^{(t)}(s)\, ds

其中 Pn(t)P_n^{(t)} 是适当定义的正交多项式(在区间 [0,t][0, t] 上的 Legendre 多项式经过缩放和平移)。

Gu et al. (2020) 的关键推导表明:这个投影系数的更新规则恰好可以写成状态空间形式 x˙(t)=Ax(t)+Bu(t)\dot{\mathbf{x}}(t) = A\mathbf{x}(t) + Bu(t),其中 AA 是一个特定的矩阵。

HiPPO-LegS 矩阵

最重要的 HiPPO 变体是 HiPPO-LegS(Leg = Legendre,S = Scaled),其状态矩阵 AA 的定义为:

-(2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ -(n+1) & \text{if } n = k \\ 0 & \text{if } n < k \end{cases}$$ 逐项理解: - **对角元素** $A_{nn} = -(n+1)$:第 $n$ 个记忆分量的自衰减速率。$n$ 越大,衰减越快——高阶多项式系数(编码细节)比低阶系数(编码大趋势)衰减得更快 - **下三角元素** $A_{nk} = -(2n+1)^{1/2}(2k+1)^{1/2}$($n > k$):高阶分量从低阶分量接收信息。$(2n+1)^{1/2}$ 和 $(2k+1)^{1/2}$ 是 Legendre 多项式的归一化因子 - **上三角元素 = 0**:$A$ 是下三角矩阵——信息只从低阶流向高阶,不会倒流 **为什么这个矩阵能编码"最优记忆"?** 因为 Legendre 多项式是在均匀测度下的正交基。HiPPO-LegS 对应的记忆策略是:**对所有过去时刻给予等权重**(uniform measure on $[0, t]$)——它在 $L^2$ 意义下最优地压缩了 $u$ 在 $[0, t]$ 上的全部历史。 对比其他记忆策略: - **指数衰减记忆**(如经典 RNN):近的记得清,远的快速遗忘——$A$ 是标量或简单对角矩阵 - **HiPPO-LegS**:在整个历史上做最优多项式逼近——能同时保留短期细节和长期趋势 ### 数值例子:$N = 4$ 的 HiPPO 矩阵 取 $N = 4$,HiPPO-LegS 矩阵为: $$A = \begin{bmatrix} -1 & 0 & 0 & 0 \\ -\sqrt{3} & -2 & 0 & 0 \\ -\sqrt{5} & -\sqrt{15} & -3 & 0 \\ -\sqrt{7} & -\sqrt{21} & -\sqrt{35} & -4 \end{bmatrix}$$ 验证对角元素:$A_{00} = -1$, $A_{11} = -2$, $A_{22} = -3$, $A_{33} = -4$——确实是 $-(n+1)$。 验证一个下三角元素:$A_{31} = -(2 \times 3 + 1)^{1/2}(2 \times 1 + 1)^{1/2} = -\sqrt{7}\sqrt{3} = -\sqrt{21}$ ✓ **特征值**:由于 $A$ 是下三角矩阵,特征值就是对角元素: $$\lambda_0 = -1, \quad \lambda_1 = -2, \quad \lambda_2 = -3, \quad \lambda_3 = -4$$ 所有特征值都是负实数——系统是**稳定的**([Art. 17 连续系统与 Kalman](../matrix-math-continuous-systems-kalman) 中的稳定性判据:$\text{Re}(\lambda_i) < 0$)。 ### HiPPO 特征值的物理含义 每个特征值 $\lambda_n = -(n+1)$ 对应一个衰减时间常数 $\tau_n = 1/|{\lambda_n}| = 1/(n+1)$: - $\lambda_0 = -1$:$\tau_0 = 1$——最慢的衰减,编码信号的最粗粒度趋势 - $\lambda_1 = -2$:$\tau_1 = 0.5$——稍快的衰减 - $\lambda_{N-1} = -N$:$\tau_{N-1} = 1/N$——最快的衰减,编码信号的最细节变化 这形成了一个**多尺度记忆结构**:不同的状态分量以不同的速率"遗忘",从而同时保留不同时间尺度的信息。 下面的交互组件展示了 HiPPO 矩阵的特征值分布和对应的记忆衰减曲线。 <HiPPOSpectrumVis client:visible locale="zh" /> **读图要点**: - **左图**:所有特征值都在复平面的左半平面($\text{Re}(\lambda) < 0$),保证系统稳定。特征值均匀分布在负实轴上——这是 HiPPO-LegS 的特征 - **右图**:$\lambda = -1$ 的曲线衰减最慢(长期记忆),$\lambda = -N$ 的曲线衰减最快(短期细节)。这就是"多尺度记忆"的可视化 - **切换到 "S4 DPLR" 视图**:对角化后特征值获得虚部,形成共轭对——虚部编码振荡频率,下文详述 ## 离散化:从连续到计算 ### 为什么需要离散化 HiPPO 给了我们一个好的 $A$ 矩阵,但连续微分方程 $\dot{\mathbf{x}} = A\mathbf{x} + Bu$ 不能直接在数字计算机上运行。我们需要把它**离散化**为递推方程 $\mathbf{x}_k = \bar{A}\mathbf{x}_{k-1} + \bar{B}u_k$。 在 [Art. 17 连续系统与 Kalman](../matrix-math-continuous-systems-kalman) 中,我们详细讨论了两种离散化方法。现在直接回顾并应用到 SSM 语境中。 ### 双线性变换(S4 采用的方法) S4 论文 (Gu, Goel & Ré, 2022) 采用双线性变换(Bilinear / Tustin's method),公式为: $$\bar{A} = \left(I - \frac{\Delta}{2}A\right)^{-1}\left(I + \frac{\Delta}{2}A\right)$$ $$\bar{B} = \left(I - \frac{\Delta}{2}A\right)^{-1} \Delta B$$ 其中 $\Delta > 0$ 是步长参数(对应 Art. 17 中的 $\Delta t$)。 逐项理解: - $I + \frac{\Delta}{2}A$:向前看半步——一阶 Euler 前进半步 - $I - \frac{\Delta}{2}A$:向后看半步——隐式 Euler 后退半步 - 左乘 $(I - \frac{\Delta}{2}A)^{-1}$:梯形法则的效果——取前后两端的平均,精度 $O(\Delta^2)$ - 整体效果:将连续系统的左半平面($\text{Re}(\lambda) < 0$)**精确映射**到离散系统的单位圆内部($|\bar{\lambda}| < 1$)——**无论 $\Delta$ 多大,稳定性都被保持** 这正是 Art. 17 中总结的双线性变换的"无条件稳定性保持"性质。 ### 零阶保持(ZOH,Mamba 采用的方法) Mamba (Gu & Dao, 2023) 改用零阶保持(Zero-Order Hold)离散化: $$\bar{A} = e^{A\Delta}$$ $$\bar{B} = A^{-1}(e^{A\Delta} - I)B \approx (\Delta A)^{-1}(e^{\Delta A} - I) \cdot \Delta B$$ 当 $A$ 是对角矩阵 $A = \text{diag}(a_1, \ldots, a_N)$ 时,矩阵指数变成逐元素运算: $$\bar{A} = \text{diag}(e^{a_1 \Delta}, e^{a_2 \Delta}, \ldots, e^{a_N \Delta})$$ 这就是 [Art. 17 连续系统与 Kalman](../matrix-math-continuous-systems-kalman) 中矩阵指数通过对角化计算的结论($e^{At} = Q\,\text{diag}(e^{\lambda_1 t}, \ldots, e^{\lambda_n t})\, Q^{-1}$)在 $Q = I$(已经是对角矩阵)时的特殊情形。 ### 矩阵指数 $e^{A\Delta}$ 的角色 矩阵指数 $e^{A\Delta}$ 在 SSM 中扮演着核心角色。回顾 Art. 17 的结论: - 连续系统的解是 $\mathbf{x}(t) = e^{At}\mathbf{x}(0)$ - 对角化后 $e^{At} = Q\,\text{diag}(e^{\lambda_1 t}, \ldots, e^{\lambda_n t})\, Q^{-1}$ - 每个特征方向独立演化:$e^{\lambda_i t}$ 的模决定衰减/增长,辐角决定振荡 在 SSM/Mamba 中,$A$ 被直接约束为对角矩阵,所以 $Q = I$,$e^{A\Delta}$ 就是逐元素标量运算。**这是对角化的极端形式——不需要做分解,矩阵本身就已经是对角的。** 为什么要在连续域定义 $A$,再离散化到 $\bar{A}$,而不是直接学 $\bar{A}$? 1. **分辨率不变性**:连续参数 $A$ 不依赖于具体的时间步长 $\Delta$。训练时用 $\Delta = 0.01$,推理时可以改为 $\Delta = 0.005$(更高分辨率)或 $\Delta = 0.02$(更低分辨率),只需要重新计算 $\bar{A} = e^{A\Delta}$ 2. **稳定性保证**:只要 $A$ 的所有对角元素的实部为负($\text{Re}(a_i) < 0$),离散化后 $|\bar{a}_i| = |e^{a_i \Delta}| = e^{\text{Re}(a_i)\Delta} < 1$,稳定性自动保持 3. **物理直觉**:连续域的参数有明确的物理意义——$\text{Re}(a_i)$ 是衰减速率,$\text{Im}(a_i)$ 是振荡频率 ### 离散化方法对比(回顾 Art. 17) | 方法 | $\bar{A}$ | 稳定性保持 | 用在哪里 | |------|----------|-----------|---------| | 双线性变换 | $(I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A)$ | 无条件 | S4 | | ZOH | $e^{A\Delta}$ | 是($\text{Re}(\lambda) < 0 \Rightarrow \|e^{\lambda\Delta}\| < 1$) | Mamba | | 欧拉前向 | $I + \Delta A$ | **不保证** | 教学 | ## 对角化是关键:从 $O(N^2L)$ 到 $O(NL\log L)$ 现在来到整篇文章的核心。 ### 问题:密集矩阵递推的瓶颈 离散化后,SSM 的递推公式为: $$\mathbf{x}_k = \bar{A}\mathbf{x}_{k-1} + \bar{B}u_k$$ $$y_k = C\mathbf{x}_k$$ 如果 $\bar{A} \in \mathbb{R}^{N \times N}$ 是一个**密集矩阵**,每步递推需要一个矩阵-向量乘法 $\bar{A}\mathbf{x}_{k-1}$,复杂度为 $O(N^2)$。对长度为 $L$ 的序列,总复杂度为 $O(N^2 L)$。 对于 HiPPO 这样的结构化矩阵($N$ 通常取 64-256),$N^2$ 项是瓶颈。更关键的是,递推天然是**顺序的**——第 $k$ 步依赖第 $k-1$ 步的结果,无法并行化。 ### 洞察:卷积展开 仔细审视递推公式: $$\mathbf{x}_0 = \bar{B}u_0$$ $$\mathbf{x}_1 = \bar{A}\bar{B}u_0 + \bar{B}u_1$$ $$\mathbf{x}_2 = \bar{A}^2\bar{B}u_0 + \bar{A}\bar{B}u_1 + \bar{B}u_2$$ $$\vdots$$ $$\mathbf{x}_k = \sum_{j=0}^{k} \bar{A}^{k-j}\bar{B}u_j$$ 输出 $y_k = C\mathbf{x}_k = \sum_{j=0}^{k} C\bar{A}^{k-j}\bar{B} \cdot u_j$。 定义**卷积核**: $$\bar{K}_j = C\bar{A}^j\bar{B}, \quad j = 0, 1, \ldots, L-1$$ 那么输出就是输入和卷积核的(因果)卷积: $$y_k = \sum_{j=0}^{k} \bar{K}_{k-j} \cdot u_j = (\bar{K} * u)_k$$ 这就是 SSM 的**双重计算模式**: - **递推模式**:$\mathbf{x}_k = \bar{A}\mathbf{x}_{k-1} + \bar{B}u_k$——逐步计算,$O(NL)$ 但顺序执行 - **卷积模式**:$\mathbf{y} = \bar{K} * \mathbf{u}$——通过 FFT 并行计算,$O(L\log L)$ **但是**,计算卷积核 $\bar{K}$ 需要计算 $\bar{A}^0\bar{B}, \bar{A}^1\bar{B}, \ldots, \bar{A}^{L-1}\bar{B}$——这就是 $L$ 次矩阵-向量乘法,复杂度 $O(N^2 L)$。瓶颈并没有消失,只是从递推转移到了卷积核的计算。 ### 对角化:打破瓶颈 这就是 [Art. 2 特征分解](../matrix-math-eigendecomposition)中的核心工具登场的时刻。 如果 $\bar{A}$ 可以对角化为 $\bar{A} = V\Lambda V^{-1}$([Art. 2 特征分解](../matrix-math-eigendecomposition) 的 $Q\Lambda Q^{-1}$,这里用 $V$ 以与 SSM 文献一致),那么: $$\bar{A}^j = V\Lambda^j V^{-1}$$ 代入卷积核: $$\bar{K}_j = C\bar{A}^j\bar{B} = (CV)\Lambda^j(V^{-1}\bar{B})$$ 定义 $\tilde{C} = CV \in \mathbb{R}^{1 \times N}$ 和 $\tilde{B} = V^{-1}\bar{B} \in \mathbb{R}^{N \times 1}$,其中 $\Lambda = \text{diag}(\bar{\lambda}_1, \ldots, \bar{\lambda}_N)$: $$\bar{K}_j = \tilde{C}\,\text{diag}(\bar{\lambda}_1^j, \ldots, \bar{\lambda}_N^j)\,\tilde{B} = \sum_{i=1}^{N} \tilde{C}_i \cdot \bar{\lambda}_i^j \cdot \tilde{B}_i$$ 现在,$\bar{K}_j$ 分解为 $N$ 个**几何级数**的加权和。每个几何级数 $\bar{\lambda}_i^0, \bar{\lambda}_i^1, \ldots, \bar{\lambda}_i^{L-1}$ 可以在 $O(L)$ 时间内计算(逐元素乘法),$N$ 个级数总计 $O(NL)$。 再加上 FFT 卷积的 $O(L\log L)$,总复杂度为: $$\boxed{O(NL + L\log L) = O(NL\log L)}$$ 对比密集矩阵的 $O(N^2L)$:当 $N = 64, L = 10000$ 时,加速比约为 $N / \log L \approx 64 / 13 \approx 5\times$;但更重要的是卷积模式**完全可并行化**。 **这就是 Art. 2 特征分解中"对角化简化一切"在现代架构中的终极应用。** $A^n = Q\Lambda^n Q^{-1}$ 这个公式——我们在路径的第二篇文章中就学到的——在这里直接转化为计算效率的飞跃。 ### S4 的突破:HiPPO 矩阵的 DPLR 结构 但 HiPPO-LegS 矩阵不是对称矩阵(它是下三角的),对角化并不简单。S4 (Gu, Goel & Ré, 2022) 的关键贡献是发现 HiPPO 矩阵具有 **DPLR 结构**(Diagonal Plus Low-Rank): $$A = \Lambda_{\text{diag}} - \mathbf{p}\mathbf{q}^T$$ 其中 $\Lambda_{\text{diag}}$ 是对角矩阵,$\mathbf{p}\mathbf{q}^T$ 是秩一矩阵。 更精确地说,HiPPO 矩阵可以被写为正规矩阵加上低秩修正的形式(NPLR,Normal Plus Low-Rank),而正规矩阵可以通过酉变换对角化。这意味着只需要处理一个低秩修正项——利用 Woodbury 恒等式,矩阵求逆和矩阵指数都可以在 $O(N)$ 时间内完成。 S4 的完整计算流程: 1. 将 HiPPO 矩阵分解为 NPLR/DPLR 结构 2. 在对角基下计算卷积核——每个频率分量独立 3. 用 FFT 完成卷积 4. 总复杂度:$O(NL\log L)$ ## 从 S4 到 Mamba:选择性机制 ### S4 的局限 S4 将 SSM 的计算效率提升到了可用水平,但它有一个关键限制:**$(A, B, C)$ 在训练后固定不变**——对于不同的输入 token,状态更新规则完全相同。 这意味着 S4 是一个**线性时不变(LTI)系统**:$\bar{A}, \bar{B}, C$ 不依赖于输入 $u_k$。LTI 的好处是可以用卷积模式高效计算,但坏处是**缺乏选择性**——模型不能根据输入内容决定"记住什么、遗忘什么"。 对比 Transformer 的 attention:注意力权重是**输入依赖的**($\text{softmax}(QK^T/\sqrt{d_k})$ 中 $Q, K$ 来自输入),所以 Transformer 天然具有选择性——它可以决定关注序列中的哪些位置。 ### Mamba 的核心创新:让参数依赖输入 Mamba (Gu & Dao, 2023) 的核心思路是:**让离散化参数 $\Delta$、输入矩阵 $B$ 和输出矩阵 $C$ 依赖于当前输入 $x_t$**。 具体来说,对于输入序列中的第 $k$ 个 token $x_k \in \mathbb{R}^D$($D$ 是模型维度): $$\Delta_k = \text{softplus}(W_\Delta x_k + b_\Delta) \in \mathbb{R}^D$$ $$B_k = W_B x_k \in \mathbb{R}^{N}$$ $$C_k = W_C x_k \in \mathbb{R}^{N}$$ 其中 $W_\Delta, W_B, W_C$ 是可学习的投影矩阵,softplus 保证 $\Delta_k > 0$。 注意:**$A$ 矩阵本身不依赖输入**——它是全局共享的对角矩阵 $A = \text{diag}(a_1, \ldots, a_N)$。但通过 $\Delta_k$ 的变化,离散化后的 $\bar{A}_k = e^{A \Delta_k}$ 变成了**输入依赖的**。 ### 选择性机制的直觉 $\Delta$ 可以理解为"时间步长"或"门控信号": - **$\Delta_k$ 大**:$\bar{A}_k = e^{A\Delta_k}$ 中的衰减更强(因为 $A$ 的对角元素为负,$\Delta$ 越大,$e^{a_i \Delta}$ 越接近 0)→ 更多地**遗忘**旧状态,更多地**吸收**新输入 - **$\Delta_k$ 小**:$\bar{A}_k \approx I$(因为 $e^{a_i \cdot 0} = 1$)→ 几乎**保持**旧状态不变,忽略当前输入 所以模型可以学会: - 遇到重要的 token(如关键词、分隔符)→ 输出大的 $\Delta$ → 重置状态,吸收新信息 - 遇到不重要的 token(如填充词、噪声)→ 输出小的 $\Delta$ → 保持记忆不变 **这就是"选择性"(selectivity)的含义——模型通过学习 $\Delta$ 来选择性地记忆和遗忘。** ### 代价:失去卷积模式 选择性带来了一个计算上的代价:因为 $\bar{A}_k, \bar{B}_k, C_k$ 现在随时间变化(time-varying),SSM 不再是 LTI 系统,卷积模式**不再适用**。 Mamba 必须回到递推模式: $$\mathbf{h}_k = \bar{A}_k \mathbf{h}_{k-1} + \bar{B}_k u_k$$ $$y_k = C_k \mathbf{h}_k$$ 但因为 $\bar{A}_k$ 是**对角矩阵**,每步递推是 $O(N)$ 的逐元素乘法(而非 $O(N^2)$ 的矩阵乘法)。总复杂度为 $O(NL)$——线性于序列长度。 Gu & Dao (2023) 还开发了一套**硬件感知的并行扫描算法**(hardware-aware parallel scan),利用 GPU 的并行能力将顺序递推加速到接近卷积的效率。关键技巧是:利用对角矩阵乘法的结合律,将长度为 $L$ 的顺序扫描分解为 $O(\log L)$ 层的并行归约。 ### S4 → Mamba 的演化总结 | | S4 | Mamba | |--|---|-------| | $A$ 的结构 | DPLR(从 HiPPO 初始化) | 对角矩阵(直接参数化) | | $B, C$ | 固定(训练后不变) | **输入依赖**($B_k = W_B x_k$) | | $\Delta$ | 固定超参数 | **输入依赖**($\Delta_k = \text{softplus}(W_\Delta x_k)$) | | 选择性 | 无(LTI 系统) | **有**(通过 $\Delta, B, C$ 的输入依赖性) | | 计算模式 | 卷积(FFT) | 递推(并行扫描) | | 训练复杂度 | $O(NL\log L)$ | $O(NL)$ | | 推理复杂度 | $O(NL)$(递推模式) | $O(NL)$ | ### DSS 到 S4D:对角化参数化的简化 值得一提的是 S4 到 Mamba 之间的中间步骤。Gu et al. (2022, NeurIPS) 在 DSS(Diagonal State Spaces)和 S4D 中发现:**直接将 $A$ 参数化为对角矩阵**,跳过 HiPPO 的 DPLR 分解,在大多数任务上表现相当。 具体做法:$A = \text{diag}(a_1, \ldots, a_N)$,其中 $a_i \in \mathbb{C}$(允许复数),初始化为 HiPPO 矩阵的特征值 $a_i = -(i+1)$(或对角化后的 DPLR 特征值)。 这个简化的关键洞察是:HiPPO 矩阵的**特征值**(而非矩阵本身)才是真正重要的东西。下三角结构只是 Legendre 多项式投影的一种表示,换到特征基后,本质信息被特征值完全捕获。 > **这正是 Art. 2 的核心教训——对角化揭示了线性变换的本质结构。矩阵的很多表观复杂性只是坐标选择的产物,在特征基下,一切变成了逐方向独立缩放。** ## 完整的 Mamba 计算流程 让我们把所有部件组装起来,给出 Mamba 的完整前向计算流程: **输入**:序列 $\mathbf{x} = (x_1, x_2, \ldots, x_L)$,$x_k \in \mathbb{R}^D$ **参数**(可学习): - $A = \text{diag}(a_1, \ldots, a_N) \in \mathbb{C}^{N \times N}$:状态矩阵(全局共享,负实部) - $W_\Delta \in \mathbb{R}^{D \times R}$, $b_\Delta \in \mathbb{R}^D$:步长投影($R$ 是低秩维度) - $W_B \in \mathbb{R}^{N \times D}$:输入矩阵投影 - $W_C \in \mathbb{R}^{N \times D}$:输出矩阵投影 **对每个位置 $k = 1, \ldots, L$**: 1. **计算输入依赖参数**: - $\Delta_k = \text{softplus}(W_\Delta x_k + b_\Delta)$ - $B_k = W_B x_k$ - $C_k = W_C x_k$ 2. **离散化**(ZOH): - $\bar{A}_k = e^{A \Delta_k} = \text{diag}(e^{a_1 \Delta_{k}}, \ldots, e^{a_N \Delta_{k}})$ - $\bar{B}_k = A^{-1}(e^{A\Delta_k} - I)B_k$(对角矩阵,逐元素) 3. **状态更新**(对角矩阵乘法,$O(N)$): - $\mathbf{h}_k = \bar{A}_k \odot \mathbf{h}_{k-1} + \bar{B}_k \odot u_k$ 4. **输出投影**: - $y_k = C_k^T \mathbf{h}_k$ 其中 $\odot$ 表示逐元素乘法(因为 $\bar{A}_k$ 是对角矩阵,矩阵-向量乘法退化为逐元素乘法)。 **整个过程是 $O(NL)$ 的**——线性于序列长度和状态维度。 ## 全路径回顾:三条弧线的汇聚 这是矩阵数学路径的最后一篇文章。让我们站在终点回望,看看 27 篇文章如何构成一个完整的叙事。 ### 弧线一:拆(Part 1, Art. 1-13) Part 1 的核心任务是**建立工具**。我们学会了四件工具: - **[Art. 2 特征分解](../matrix-math-eigendecomposition)**:$A = Q\Lambda Q^{-1}$——在特征基下,线性变换退化为逐方向缩放 - **[Art. 3 SVD](../matrix-math-svd)**:推广到任意矩阵 $A = U\Sigma V^T$,Eckart-Young 定理保证截断 SVD 是最佳低秩近似 - **[Art. 4 范数](../matrix-math-norms)**:度量矩阵和近似的"大小" - **[Art. 5 微积分](../matrix-math-calculus)**:矩阵参数的梯度计算 然后用这些工具拆解数据矩阵:PCA(Art. 6)、随机化 SVD(Art. 7)、矩阵补全(Art. 8)、NMF(Art. 9)、矩阵分解与 FM(Art. 10)、Word2Vec(Art. 11)、鲁棒 PCA(Art. 12)、张量分解(Art. 13)。 **在 SSM/Mamba 中的回响**:Art. 2 的特征分解公式 $A^n = Q\Lambda^n Q^{-1}$ 直接给出了 S4 卷积核的高效计算方法。Part 1 建立的"对角化简化一切"这个洞察,在 20 多篇文章之后,成为了一个现代架构的设计原则。 ### 弧线二:传(Part 2, Art. 14-22) Part 2 的核心任务是**分析给定算子**。矩阵不再只是装数据的容器,而是编码过程的算子。 - **时序子线**:马尔可夫链(Art. 15)→ HMM(Art. 16)→ 连续系统与 Kalman(Art. 17) - **图/空间子线**:PageRank(Art. 18)→ 随机游走(Art. 19)→ Kernel(Art. 20)→ 图 Laplacian(Art. 21)→ GNN(Art. 22) **在 SSM/Mamba 中的回响**:Art. 17 的连续时间 state-space 公式 $(\dot{\mathbf{x}} = A\mathbf{x} + Bu, \mathbf{y} = C\mathbf{x})$ 和离散化方法(ZOH、双线性变换)被 SSM/Mamba **原封不动地复用**。Art. 15 马尔可夫链中"反复乘以矩阵"的概念,在 SSM 的递推模式 $\mathbf{h}_k = \bar{A}\mathbf{h}_{k-1} + \bar{B}u_k$ 中找到了直接的对应。 ### 弧线三:汇(Part 3, Art. 23-26) Part 3 的核心转变是**矩阵从"给定"变为"学习"**。 - **[Art. 23 学习算子概述](../matrix-math-learned-operators)**:揭示训练好的权重矩阵的经验低秩性,预览三个应用方向 - **Art. 24 LoRA**:利用微调增量的低秩性压缩可训练参数 - **Art. 25 Efficient Attention**:利用注意力矩阵的近似低秩性加速推理 - **Art. 26 SSM/Mamba(本文)**:利用状态矩阵的对角结构实现线性时间序列建模 ### 汇聚点:SSM/Mamba SSM/Mamba 是三条弧线的交汇点: | 来自哪里 | 贡献了什么 | |---------|----------| | Art. 2 特征分解 | $A = Q\Lambda Q^{-1}$, $A^n = Q\Lambda^n Q^{-1}$ → 卷积核的高效计算 | | Art. 17 Kalman | State-space 公式 + 离散化方法 → SSM 的建模框架 | | Art. 23 学习算子 | "给定"变"学习" → $A, B, C$ 成为可训练参数 | 甚至更细粒度地: - **对角化**(Art. 2)→ S4 的 DPLR 分解 → Mamba 的对角参数化 - **矩阵指数**(Art. 17)→ ZOH 离散化 $\bar{A} = e^{A\Delta}$ - **稳定性判据**(Art. 17)→ $\text{Re}(a_i) < 0$ 保证 SSM 的稳定性 - **经验低秩性**(Art. 23)→ 动机:利用结构简化计算 **三段弧线的一句话总结**: > Part 1 教我们"对角化简化一切",Part 2 展示了矩阵指数和状态空间模型的威力,Part 3 把给定算子变成可学习参数——SSM/Mamba 在这三个洞察的交汇处,用对角化将学到的状态空间模型变成了线性时间的序列建模器。 ## 总结 本文——也是整条路径的最后一篇——介绍了 SSM/Mamba 如何将矩阵对角化这个基础工具转化为高效架构设计原则: - **HiPPO 矩阵**提供了"最优记忆"的数学基础:$A_{nk}$ 的特定结构编码了 Legendre 多项式投影,特征值 $\lambda_n = -(n+1)$ 形成多尺度记忆 - **离散化**(双线性变换/ZOH)将连续时间 SSM 转换为可计算的递推方程,对照 Art. 17 的方法 - **对角化是关键突破**:$A = V\Lambda V^{-1}$ 将 $O(N^2L)$ 的密集递推转化为 $O(NL\log L)$ 的卷积(S4)或 $O(NL)$ 的对角递推(Mamba) - **选择性机制**(Mamba)让 $\Delta, B, C$ 依赖输入,使模型能够选择性地记忆和遗忘——以失去卷积模式为代价,换取内容感知的序列处理 - **矩阵指数** $e^{A\Delta}$ 扮演核心角色:连接连续参数和离散计算,保持稳定性 从整条路径的视角看: > **同一个公式 $A^n = Q\Lambda^n Q^{-1}$——在 Art. 2 中是一个数学定理,在 Art. 15 中分析马尔可夫链的长期行为,在 Art. 17 中计算矩阵指数,在 Art. 26 中成为高效架构的设计原则。这就是"拆→传→汇"弧线的完整闭合:Part 1 的工具建立了分解和对角化的数学基础,Part 2 展示了这些工具在分析给定系统中的力量,Part 3 将它们应用于设计和优化学习到的计算。** 矩阵数学路径至此完结。感谢你走完了这 27 篇文章的旅程。