一句话总结:Transformer 的 O(N^2) 复杂度限制了它处理超长序列的能力。State Space Models (SSM) 用 O(N) 的线性复杂度实现序列建模,而 Mamba 通过"选择性"机制让 SSM 首次在语言任务上匹敌 Transformer。混合架构如 Jamba 结合了两者优势,代表了后 Transformer 时代的重要方向。
32.1 Transformer 的瓶颈
32.1.1 回顾:Attention 的复杂度
在前面的章节中,我们深入学习了 Transformer 的核心机制。让我们回顾一下 Self-Attention 的计算:
这个看似优雅的公式隐藏着一个巨大的问题:计算和内存复杂度都是 O(N^2)。
为什么是 O(N^2)?因为 这一步:
- Q 的形状是 [N, d]
- K 的形状是 [N, d]
- 的形状是 [N, N]
这意味着每个 token 都要与所有其他 token 计算注意力分数。
32.1.2 O(N^2) 问题有多严重?
让我们算一笔账:
| 序列长度 N | 注意力矩阵大小 | 内存占用 (FP16) |
|---|---|---|
| 1,024 | 1M | 2 MB |
| 4,096 | 16M | 32 MB |
| 32,768 | 1B | 2 GB |
| 131,072 (128K) | 17B | 34 GB |
| 1,000,000 (1M) | 1T | 2 TB |
当序列长度从 4K 增加到 128K(32 倍),内存占用增加了 1024 倍!
这就是为什么:
- GPT-4 早期版本只支持 8K 上下文
- Claude 扩展到 100K 需要大量工程优化
- 处理整本书、代码仓库、长视频仍然是巨大挑战
32.1.3 KV Cache 的局限
在第 22 章,我们学习了 KV Cache。它通过缓存历史的 K 和 V 来加速推理。但 KV Cache 也有自己的问题:
内存随序列线性增长:
KV Cache 大小 = 2 × layers × n_heads × d_head × seq_len × precision
对于 Llama-70B:
- 处理 4K tokens:约 2.5 GB
- 处理 128K tokens:约 80 GB
- 处理 1M tokens:约 625 GB(需要多台服务器)
这产生了两个核心矛盾:
- 长上下文 vs 并发:同一台 GPU 上,支持更长的上下文就意味着更少的并发用户
- 上下文窗口 vs 成本:扩展上下文窗口的成本是超线性增长的
核心问题:有没有一种架构,既能像 Transformer 一样强大,又能以 O(N) 的复杂度处理序列?
这就是 State Space Models 和 Mamba 要解决的问题。
32.2 State Space Models (SSM) 基础
32.2.1 从 RNN 说起
在 Transformer 之前,序列建模的主流方法是 RNN(循环神经网络)。
RNN 的核心思想很简单:用一个"隐状态"来压缩历史信息。
h_t = f(h_{t-1}, x_t)
y_t = g(h_t)
- h_t 是时刻 t 的隐状态
- x_t 是时刻 t 的输入
- 每一步只需要上一步的隐状态和当前输入
RNN 的优点:
- O(N) 复杂度:每个时间步计算量固定
- O(1) 内存:只需要保存当前隐状态
RNN 的缺点:
- 顺序计算:无法并行,训练慢
- 梯度消失/爆炸:难以捕捉长距离依赖
- 表达能力有限:固定大小的隐状态是信息瓶颈
32.2.2 什么是 State Space Models?
State Space Models(状态空间模型)起源于控制论和信号处理。它用微分方程描述系统的演化:
连续形式:
其中:
- x(t) 是输入信号
- h(t) 是隐状态(state)
- y(t) 是输出信号
- A, B, C, D 是系统参数
直觉理解:
- A 决定了隐状态如何"自我演化"
- B 决定了输入如何"写入"隐状态
- C 决定了如何从隐状态"读出"输出
- D 是输入到输出的"直连"
32.2.3 连续 vs 离散
计算机处理的是离散序列,我们需要将连续的 SSM 离散化:
这里 和 是通过离散化得到的参数。
常用的离散化方法是零阶保持(Zero-Order Hold):
其中 是采样间隔。
32.2.4 SSM 的两种视角
SSM 有两种等价的计算方式:
视角一:递归形式(用于推理)
h_0 = 0
for k in range(seq_len):
h_k = A_bar @ h_{k-1} + B_bar @ x_k
y_k = C @ h_k
- 复杂度:O(N)
- 无法并行
视角二:卷积形式(用于训练)
# 预计算卷积核
K = [C @ B_bar, C @ A_bar @ B_bar, C @ A_bar^2 @ B_bar, ...]
# 一次卷积
y = conv1d(x, K)
- 可以并行计算
- 利用 FFT 加速到 O(N log N)
SSM 的魔法:训练时用卷积形式并行计算,推理时用递归形式 O(1) 内存!
32.2.5 S4:SSM 的突破
2021 年,斯坦福的 Albert Gu 等人提出了 S4(Structured State Space for Sequence Modeling)。
S4 的核心贡献:
- 结构化 A 矩阵:使用 HiPPO 初始化,让模型能记住长距离信息
- 高效计算:利用矩阵结构,将计算复杂度降到 O(N log N)
- 长序列建模:在 Path-X(16K 序列分类)等长序列任务上超越 Transformer
但 S4 有一个关键问题:参数是输入无关的。
这意味着 A、B、C 矩阵在整个序列上是固定的。模型无法根据输入内容"选择性"地关注或忘记信息。
这正是 Mamba 要解决的问题。
32.3 Mamba 的核心创新
32.3.1 论文背景
2023 年 12 月,CMU 和 Princeton 的研究者发布了 Mamba: Linear-Time Sequence Modeling with Selective State Spaces。
这篇论文引起了巨大关注,因为它首次证明:SSM 可以在语言建模任务上匹敌 Transformer。
32.3.2 核心问题:内容感知
传统 SSM 的参数是"静态的"。让我们用一个例子理解这个问题:
输入序列:"The quick brown fox jumps over the lazy dog."
传统 SSM:
- 对 "The" 和 "fox" 使用相同的 A、B 矩阵
- 无法根据 token 的重要性调整记忆行为
这就像一个人用同样的注意力读每一个字,无法区分重要和不重要的信息。
32.3.3 Mamba 的解决方案:选择性 SSM
Mamba 的核心创新是让 SSM 的参数依赖于输入:
注意:
- (离散化步长)变成了输入的函数
- B(输入投影)变成了输入的函数
- C(输出投影)变成了输入的函数
直觉理解:
- 当 很大时,隐状态快速更新 → "关注当前输入"
- 当 很小时,隐状态保持稳定 → "保持历史记忆"
- 模型可以根据输入内容选择记住什么、忘记什么
32.3.4 选择性机制的力量
让我们回到之前的例子:
输入:"The quick brown fox jumps over the lazy dog."
选择性 SSM (Mamba):
- 对 "The"(冠词):小 Δ → 快速遗忘
- 对 "fox"(关键名词):大 Δ → 更新隐状态
- 对 "jumps"(动词):根据上下文动态调整
这正是 Attention 的本质:根据内容决定关注什么。
Mamba 用一种完全不同的方式实现了类似的效果。
32.3.5 硬件感知算法
让参数依赖于输入带来一个问题:无法再用卷积形式并行计算。
Mamba 的解决方案是设计硬件感知的算法:
- 核心观察:虽然无法用全局卷积,但可以在 GPU 内存层级上优化
- Selective Scan:一种新的并行扫描算法
- Kernel Fusion:将多个操作融合,减少内存访问
这使得 Mamba 在实际硬件上的速度:
- 训练:比 Flash Attention 快 3 倍
- 推理:随序列长度线性扩展
32.3.6 Mamba 架构
完整的 Mamba 块结构:
输入 x
|
+---> Conv1d --> SiLU --> SSM
| |
+------------------------>+
(Skip Connection) |
|
Norm
|
输出
关键设计:
- 无注意力:完全移除了 self-attention
- 扩展再压缩:类似 MLP 的 expand-contract 结构
- 门控机制:用 SiLU 激活函数提供非线性
32.3.7 复杂度对比
| 操作 | Transformer | Mamba |
|---|---|---|
| 计算复杂度 | O(N^2 d) | O(N d) |
| 内存复杂度 | O(N^2) | O(N) |
| 推理步骤 | O(N) | O(1) |
| 序列长度扩展 | 二次增长 | 线性增长 |
关键优势:
- 处理 128K tokens 时,Mamba 的内存和计算随 N 线性增长,而 Transformer 是 N² 增长(即 Mamba ≈ O(N) vs Transformer ≈ O(N²))
32.4 Mamba-2 与 State Space Duality
32.4.1 Mamba 的局限
虽然 Mamba 在效率上有巨大优势,但在某些任务上仍不如 Transformer:
- In-context Learning:Transformer 能更好地利用上下文中的示例
- 信息检索:在需要精确回忆特定位置信息的任务上较弱
- 复杂推理:多步推理任务表现不如 Transformer
32.4.2 Mamba-2 的突破
2024 年 5 月,同一团队发布了 Mamba-2,带来两个关键贡献:
贡献一:State Space Duality (SSD)
Mamba-2 发现:SSM 和 Attention 是同一框架的两种形式!
具体来说:
- SSM 可以写成一种"结构化矩阵"乘法
- Attention 也可以写成类似的形式
- 两者通过不同的矩阵结构统一起来
这个理论洞察被称为 State Space Duality。
贡献二:更高效的算法
基于 SSD,Mamba-2 设计了新的算法:
- 比 Mamba-1 快 2-8 倍
- 更好地利用 Tensor Cores
- 支持更大的状态维度(从 16 提升到 256)
32.4.3 SSD 的直觉
让我们用简化的视角理解 SSD:
Attention 视角:
每个位置可以直接"看到"所有其他位置
关系矩阵是 N x N 的 dense 矩阵
SSM 视角:
信息通过隐状态逐步传递
关系矩阵是 N x N 的 structured 矩阵(低秩 + 特殊结构)
SSD 证明:通过选择不同的结构约束,可以在这两个极端之间连续变化。
这意味着:
- Transformer 是一种"无结构"的极端情况
- 传统 SSM 是另一种"强结构"的极端情况
- Mamba-2 找到了中间的最优点
32.4.4 性能提升
Mamba-2 在多项任务上的表现:
| 任务 | Mamba-1 | Mamba-2 | Transformer |
|---|---|---|---|
| 语言建模 (PPL) | 基准 | -0.1 | -0.05 |
| 长序列 (16K+) | 快 | 更快 2-8x | 慢 |
| 训练吞吐 | 高 | 更高 | 中 |
核心认知:Mamba-2 不仅是工程优化,更是理论突破。SSD 为统一 SSM 和 Attention 提供了数学基础。
32.5 混合架构:Jamba
32.5.1 为什么需要混合架构?
尽管 Mamba 在效率上有优势,但 Transformer 在某些任务上仍然更强:
| 任务类型 | Transformer 优势 | Mamba 优势 |
|---|---|---|
| 短序列理解 | 强 | 适中 |
| 长序列处理 | 内存瓶颈 | 线性扩展 |
| 信息精确检索 | 强 | 较弱 |
| 多步推理 | 强 | 适中 |
| 推理速度 | 慢 | 快 |
一个自然的想法:能否结合两者的优势?
32.5.2 Jamba 架构
2024 年 3 月,AI21 Labs 发布了 Jamba,这是第一个将 Transformer、Mamba 和 MoE 结合的大规模模型。
Jamba 的核心设计:
- 混合层:交替使用 Transformer 层和 Mamba 层
- MoE 集成:部分层使用 Mixture of Experts
- 256K 上下文:支持超长上下文窗口
32.5.3 Jamba 的层结构
Jamba Block 结构(4 层为一组):
Layer 1: Mamba + MLP
Layer 2: Mamba + MLP
Layer 3: Mamba + MLP
Layer 4: Attention + MLP <-- 每 4 层有 1 层 Attention
其中 MLP 可以是普通 MLP 或 MoE
设计理念:
- 大部分层使用高效的 Mamba
- 少量 Attention 层提供"全局视野"
- MoE 增加模型容量而不增加推理成本
32.5.4 Jamba 的参数配置
| 配置 | 数值 |
|---|---|
| 总参数量 | 52B |
| 活跃参数量 | 12B |
| 层数 | 32 |
| Attention 层 | 8(每 4 层 1 个) |
| Mamba 层 | 24 |
| MoE 专家数 | 16 |
| 激活专家数 | 2 |
| 上下文窗口 | 256K |
关键指标:
- 52B 总参数,12B 激活:通过 MoE 实现
- 1:3 的 Attention/Mamba 比例:大部分是高效的 Mamba
- 256K 上下文:比同参数 Transformer 大 32 倍
32.5.5 为什么混合比纯架构更好?
场景一:处理长文档
[开头 8K tokens] ... [中间 240K tokens] ... [结尾 8K tokens]
Mamba 层:高效地压缩中间 240K tokens 的信息
Attention 层:在关键位置进行精确的信息检索
场景二:In-context Learning
[Few-shot 示例 1] [Few-shot 示例 2] ... [Query]
Mamba 层:快速处理示例
Attention 层:精确对齐 Query 和示例
场景三:推理任务
[问题描述] [推理步骤 1] [推理步骤 2] ... [答案]
Mamba 层:压缩问题描述
Attention 层:连接推理步骤
32.5.6 Jamba 的性能表现
与同规模模型对比:
| 模型 | 参数量 | 上下文 | 内存 | 吞吐量 |
|---|---|---|---|---|
| Llama-2 70B | 70B | 4K | 140 GB | 基准 |
| Mixtral 8x7B | 47B | 32K | 94 GB | 1.5x |
| Jamba | 52B (12B 激活) | 256K | 100 GB | 3x |
Jamba 的优势:
- 比 Llama-2 70B 快 3 倍
- 上下文长度是 64 倍
- 内存占用更低
32.6 其他替代架构
32.6.1 RWKV:线性注意力
RWKV(Receptance Weighted Key Value) 是另一种尝试替代 Transformer 的架构。
核心思想:
- 将 Attention 改造为"线性"形式
- 类似 RNN 的递归计算
- O(N) 复杂度
# RWKV 的核心公式(简化)
wkv_t = sum(e^(k_i + w*(t-1-i)) * v_i for i in range(t))
y_t = sigmoid(r_t) * wkv_t
特点:
- 可以像 RNN 一样逐步生成
- 也可以像 Transformer 一样并行训练
- 开源社区活跃
RWKV 版本演进:
- RWKV-4:初始版本
- RWKV-5 (Eagle):改进位置编码
- RWKV-6 (Finch):进一步优化
32.6.2 RetNet:保留网络
RetNet(Retentive Network) 由微软亚洲研究院在 2023 年提出。
核心思想:
- 将 Attention 分解为"保留"和"衰减"
- 支持并行训练、递归推理、分块推理三种模式
retention(Q, K, V) = (Q @ K^T * D) @ V
其中 D 是衰减矩阵,距离越远衰减越大
三种计算模式:
- 并行模式:训练时使用,类似 Transformer
- 递归模式:推理时使用,类似 RNN
- 分块模式:处理超长序列,分块计算
32.6.3 Hyena:长卷积
Hyena 是斯坦福在 2023 年提出的架构。
核心思想:
- 用可学习的长卷积替代 Attention
- 卷积核通过小型网络隐式生成
输入 x
|
v
生成 filters: f1, f2, ..., fn = HyenaFilter(x)
|
v
输出 y = (((x * f1) * g1) * f2) * g2 ...
特点:
- O(N log N) 复杂度
- 在长序列任务上表现优秀
- 但在语言建模上不如 Mamba
32.6.4 架构对比总结
| 架构 | 核心思想 | 复杂度 | 语言建模 | 长序列 |
|---|---|---|---|---|
| Transformer | Self-Attention | O(N^2) | 最强 | 受限 |
| Mamba | 选择性 SSM | O(N) | 接近 | 优秀 |
| RWKV | 线性注意力 | O(N) | 较好 | 优秀 |
| RetNet | 保留机制 | O(N) | 较好 | 优秀 |
| Hyena | 长卷积 | O(N log N) | 适中 | 优秀 |
| Jamba | 混合 | O(N) | 强 | 优秀 |
行业趋势:纯 SSM 架构在特定任务上有优势,但混合架构(如 Jamba)可能是最实用的方向。
32.7 代码示例
32.7.1 简化的 SSM 实现
让我们实现一个最简化的 SSM 来理解核心概念:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleSSM(nn.Module):
"""简化的 State Space Model 实现"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 状态空间参数
# A: 状态转移矩阵 (d_state, d_state)
# B: 输入投影 (d_state, d_model)
# C: 输出投影 (d_model, d_state)
self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
self.B = nn.Parameter(torch.randn(d_state, d_model) * 0.01)
self.C = nn.Parameter(torch.randn(d_model, d_state) * 0.01)
# 离散化步长
self.delta = nn.Parameter(torch.ones(1) * 0.1)
def discretize(self):
"""零阶保持离散化"""
# 简化实现:A_bar = exp(delta * A)
A_bar = torch.matrix_exp(self.delta * self.A)
# B_bar = (A_bar - I) @ A^{-1} @ B,简化为 delta * B
B_bar = self.delta * self.B
return A_bar, B_bar
def forward(self, x):
"""
x: (batch, seq_len, d_model)
返回: (batch, seq_len, d_model)
"""
batch, seq_len, d_model = x.shape
A_bar, B_bar = self.discretize()
# 递归计算(推理模式)
h = torch.zeros(batch, self.d_state, device=x.device)
outputs = []
for t in range(seq_len):
# h_t = A_bar @ h_{t-1} + B_bar @ x_t
h = h @ A_bar.T + x[:, t, :] @ B_bar.T
# y_t = C @ h_t
y = h @ self.C.T
outputs.append(y)
return torch.stack(outputs, dim=1)
32.7.2 选择性 SSM(Mamba 风格)
class SelectiveSSM(nn.Module):
"""Mamba 风格的选择性 SSM"""
def __init__(self, d_model, d_state=16, d_conv=4):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 输入投影
self.in_proj = nn.Linear(d_model, d_model * 2)
# 1D 卷积
self.conv = nn.Conv1d(
d_model, d_model,
kernel_size=d_conv,
padding=d_conv - 1,
groups=d_model
)
# 选择性参数生成
# 关键区别:B, C, delta 都依赖于输入!
self.B_proj = nn.Linear(d_model, d_state)
self.C_proj = nn.Linear(d_model, d_state)
self.delta_proj = nn.Linear(d_model, d_model)
# A 矩阵(不依赖输入,但需要特殊初始化)
A = torch.arange(1, d_state + 1).float()
self.A_log = nn.Parameter(torch.log(A))
# 输出投影
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
"""
x: (batch, seq_len, d_model)
"""
batch, seq_len, _ = x.shape
# 输入投影和门控
xz = self.in_proj(x) # (B, L, 2*D)
x, z = xz.chunk(2, dim=-1) # 各 (B, L, D)
# 1D 卷积
x = x.transpose(1, 2) # (B, D, L)
x = self.conv(x)[:, :, :seq_len]
x = x.transpose(1, 2) # (B, L, D)
x = F.silu(x)
# 选择性参数 - 这是 Mamba 的核心创新!
B = self.B_proj(x) # (B, L, d_state) - 依赖输入
C = self.C_proj(x) # (B, L, d_state) - 依赖输入
delta = F.softplus(self.delta_proj(x)) # (B, L, D) - 依赖输入
# A 矩阵
A = -torch.exp(self.A_log) # (d_state,)
# 离散化并执行 SSM
y = self.selective_scan(x, A, B, C, delta)
# 门控
y = y * F.silu(z)
return self.out_proj(y)
def selective_scan(self, x, A, B, C, delta):
"""
选择性扫描算法的简化实现
在实际的 Mamba 中,这部分使用 CUDA kernel 高度优化
"""
batch, seq_len, d_model = x.shape
d_state = A.shape[0]
# 初始化隐状态
h = torch.zeros(batch, d_model, d_state, device=x.device)
outputs = []
for t in range(seq_len):
# 离散化(每个时间步不同!)
delta_t = delta[:, t, :].unsqueeze(-1) # (B, D, 1)
A_bar = torch.exp(delta_t * A) # (B, D, d_state)
B_bar = delta_t * B[:, t, :].unsqueeze(1) # (B, D, d_state)
# 状态更新
h = h * A_bar + x[:, t, :].unsqueeze(-1) * B_bar
# 输出
y = (h * C[:, t, :].unsqueeze(1)).sum(dim=-1) # (B, D)
outputs.append(y)
return torch.stack(outputs, dim=1)
32.7.3 与 Attention 的直接对比
def compare_complexity():
"""比较 Attention 和 SSM 的复杂度"""
import time
d_model = 512
d_state = 16
batch = 4
print("序列长度 | Attention | SSM | 比值")
print("-" * 50)
for seq_len in [512, 1024, 2048, 4096, 8192]:
x = torch.randn(batch, seq_len, d_model).cuda()
# Attention (简化版)
attention = nn.MultiheadAttention(d_model, 8).cuda()
torch.cuda.synchronize()
start = time.time()
for _ in range(10):
attention(x, x, x)
torch.cuda.synchronize()
attn_time = (time.time() - start) / 10
# SSM
ssm = SelectiveSSM(d_model, d_state).cuda()
torch.cuda.synchronize()
start = time.time()
for _ in range(10):
ssm(x)
torch.cuda.synchronize()
ssm_time = (time.time() - start) / 10
ratio = attn_time / ssm_time
print(f"{seq_len:8} | {attn_time:.4f}s | {ssm_time:.4f}s | {ratio:.2f}x")
# 示例输出:
# 序列长度 | Attention | SSM | 比值
# --------------------------------------------------
# 512 | 0.0023s | 0.0018s | 1.28x
# 1024 | 0.0045s | 0.0032s | 1.41x
# 2048 | 0.0156s | 0.0061s | 2.56x
# 4096 | 0.0589s | 0.0118s | 4.99x
# 8192 | 0.2234s | 0.0233s | 9.58x
关键观察:
- 短序列时差异不大
- 序列越长,SSM 的优势越明显
- 8192 tokens 时 SSM 快接近 10 倍
32.8 Transformer vs Mamba vs Hybrid 对比
32.8.1 性能维度对比
| 维度 | Transformer | Mamba | Hybrid (Jamba) |
|---|---|---|---|
| 计算复杂度 | O(N^2) | O(N) | O(N) |
| 内存复杂度 | O(N^2) + O(N) KV Cache | O(N) + O(1) state | 介于两者 |
| 训练速度 | 中 | 快 | 快 |
| 推理延迟 | 高(长序列) | 低 | 低 |
| 语言理解 | 最强 | 强 | 很强 |
| 长序列能力 | 受限 | 优秀 | 优秀 |
| In-context Learning | 最强 | 较弱 | 强 |
| 生态成熟度 | 最成熟 | 发展中 | 发展中 |
32.8.2 适用场景
选择 Transformer 当:
- 序列长度 < 4K
- 需要最强的理解能力
- 生态工具链成熟度重要
- 不追求极致推理速度
选择 Mamba 当:
- 序列长度 > 8K
- 推理速度是关键指标
- 内存受限的环境
- 可以接受略低的语言理解能力
选择混合架构 (Jamba) 当:
- 需要处理超长上下文(32K+)
- 同时需要强语言理解和高效推理
- 愿意使用较新的技术栈
- 有 256K 级别上下文需求
32.8.3 未来趋势预测
短期(2024-2025):
- Transformer 仍是主流
- Mamba/SSM 在特定场景获得采用
- 混合架构开始进入生产
中期(2025-2027):
- 混合架构可能成为新标准
- SSM 技术持续改进
- 硬件优化跟上软件创新
长期展望:
- 可能出现完全超越 Attention 的新范式
- 或者 Attention 和 SSM 完全统一
- 架构选择将更加任务特定化
32.8.4 学术与工业的态度
学术界:
- 对 SSM 理论非常感兴趣
- State Space Duality 被认为是重要突破
- 持续探索更多替代架构
工业界:
- 谨慎观望,等待更多验证
- Google、Meta 在内部实验
- 开源社区积极跟进
- 云服务商开始支持
核心认知:Transformer 不会被轻易取代,但 Mamba 和混合架构代表了重要的技术方向。理解这些新架构,为未来做好准备。
32.9 本章要点
32.9.1 核心概念总结
-
Transformer 的瓶颈
- O(N^2) 复杂度限制长序列处理
- KV Cache 内存随序列线性增长
- 这两个问题在超长上下文时变得严重
-
SSM 的核心思想
- 用"状态"压缩历史信息
- 连续时间到离散时间的转换
- 递归形式用于推理,卷积形式用于训练
-
Mamba 的创新
- 选择性机制:参数依赖于输入
- 让 SSM 能根据内容决定记忆什么
- 硬件感知的高效算法
-
混合架构的优势
- 结合 Transformer 的理解能力和 SSM 的效率
- Jamba:52B 参数,256K 上下文
- 可能是实用的最优选择
32.9.2 关键公式
SSM 离散形式:
Mamba 的选择性:
复杂度对比:
| Attention | SSM | |
|---|---|---|
| 计算 | O(N^2 d) | O(N d) |
| 内存 | O(N^2) | O(N) |
32.9.3 实践建议
-
评估任务需求:
- 短序列(`<4K):Transformer 仍是最佳选择
- 长序列(>8K):考虑 Mamba 或混合架构
-
关注硬件支持:
- Mamba 需要专门的 CUDA kernel
- 检查框架支持(PyTorch, JAX)
-
混合架构实验:
- 尝试不同的 Transformer/Mamba 比例
- 监控质量和效率的权衡
-
保持关注:
- 这是快速发展的领域
- Mamba-2, RWKV-6 等持续改进
本章交付物
学完这一章,你应该能够:
- 解释 Transformer O(N^2) 复杂度的来源和影响
- 理解 SSM 的基本原理(状态方程、离散化)
- 说明 Mamba 的"选择性"机制为何重要
- 解释混合架构(如 Jamba)的设计理念
- 根据任务特点选择合适的架构
延伸阅读
-
Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, 2023)
- 原论文,详细介绍选择性 SSM
-
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (Dao & Gu, 2024)
- Mamba-2 论文,State Space Duality 理论
-
Jamba: A Hybrid Transformer-Mamba Language Model (AI21 Labs, 2024)
- 混合架构的工业实践
-
S4: Efficiently Modeling Long Sequences with Structured State Spaces (Gu et al., 2021)
- SSM 在深度学习中的突破
-
RWKV: Reinventing RNNs for the Transformer Era (Peng et al., 2023)
- 另一种 O(N) 替代架构
下一步
Part 9 到此全部完成。你已经从 Transformer 的基础一路学到了最前沿的发展:
- 第 28 章:Prompt Engineering - 如何更好地使用模型
- 第 29 章:RLHF 与偏好学习 - 为什么模型"听话"
- 第 30 章:Mixture of Experts - 稀疏激活的秘密
- 第 31 章:推理模型 - o1/R1 的突破
- 第 32 章:后 Transformer 架构 - Mamba 与混合模型
接下来,你可以:
- 回顾 Part 1-8,巩固基础
- 阅读原论文,深入理解细节
- 动手实验,亲自体验这些技术
- 关注最新发展,这个领域变化很快
最后的思考:Transformer 统治了过去 7 年,但没有永恒的王者。Mamba、混合架构、以及未来更多的创新,都在告诉我们:保持学习,保持开放。