IR 设计(上):SSA、FX IR 与 MLIR Dialect
更新于 2026-04-23
简介:为什么需要中间表示
在前一篇文章中,我们看到 TorchDynamo 如何通过字节码分析捕获 Python 计算图,AOTAutograd 如何将前向和反向传播联合追踪。但一个关键问题浮现:捕获到的计算图该以什么形式存储和操作?
这就是 Intermediate Representation(中间表示,简称 IR) 的核心问题。IR 是编译器架构中最重要的设计决策之一——它决定了编译器能做哪些分析和优化,不能做哪些;它决定了系统的可扩展性、可维护性和性能上限。
一个直觉类比:IR 之于编译器,如同数据结构之于算法。选错了数据结构,再好的算法也无法高效运行。同理,选错了 IR,再巧妙的优化 pass 也无法发挥作用。
本文是 IR 设计系列的上篇。我们将从编译器的基础概念出发,深入讲解三个核心主题:
- SSA(Static Single Assignment):几乎所有现代编译器 IR 的基石
- FX IR:PyTorch 2.0 的图级中间表示
- MLIR Dialect 系统:Google 提出的多层次 IR 框架
IR 基础概念
从源码到机器码的翻译链
任何编译器(无论是传统的 C/C++ 编译器还是 ML 编译器)都遵循相同的基本模式:
前端(Frontend)负责将源码解析成 IR,优化器(Optimizer)在 IR 上进行变换,后端(Backend)将优化后的 IR 翻译为目标机器码。IR 是这条链的枢纽——所有的分析和优化都在 IR 上进行。
对于 ML 编译器,这条链变得更加复杂。源码是 Python(动态类型、控制流丰富),目标是 GPU/TPU/NPU 等异构硬件。中间可能需要多层 IR,每层关注不同的抽象级别:
好的 IR 设计原则
设计一个好的 IR 需要在多个维度间权衡。以下是编译器社区长期积累的核心原则:
1. 适当的抽象层次
IR 的抽象层次决定了能表达什么、能优化什么。太高层——丢失了底层信息,无法做硬件相关优化。太低层——信息被”展平”,高层模式难以识别。
例如,在 linalg.matmul 层次,编译器知道这是矩阵乘法,可以选择 tiling 策略。降到循环嵌套后,这个语义就模糊了——编译器只看到三层循环和乘加操作,需要做 pattern matching 才能恢复矩阵乘法的语义。
2. 易于分析和变换
IR 的结构应当使常见的编译器分析(数据流分析、别名分析、支配树分析)容易实现。这就是 SSA 形式如此流行的原因——它让 use-def chain 的构建变得trivial。
3. 可扩展性
ML 编译器面对的是一个快速演化的生态:新的算子、新的硬件、新的优化技巧不断出现。IR 必须能够方便地添加新的操作类型,而不需要修改整个编译器框架。MLIR 的 Dialect 系统就是为此而生。
4. 保留足够的语义信息
好的 IR 不应在翻译过程中不必要地丢弃语义信息。例如,类型信息(tensor 的 shape、dtype)对优化至关重要,过早丢弃会导致后续 pass 不得不重新推导这些信息。
SSA 形式
什么是 SSA
Static Single Assignment(静态单赋值) 是一种 IR 的约束形式,其核心规则极为简单:
每个变量只被赋值(定义)一次。
这条看似简单的规则,彻底改变了编译器的设计方式。SSA 的概念由 IBM 研究员在 1980 年代逐步发展(Rosen、Wegman 和 Zadeck 在 1988 年引入了 函数)。Cytron 等人在 1991 年的经典论文 “Efficiently Computing Static Single Assignment Form and the Control Dependence Graph” 中提出了高效构造 SSA 的算法(基于 dominance frontier),使 SSA 成为现代编译器的标准。此后,几乎所有现代编译器都采用 SSA 作为标准 IR 形式——LLVM IR、GCC GIMPLE、Java HotSpot 的 Sea of Nodes、V8 的 TurboFan(后被 Turboshaft 替代),无一例外。
为什么需要 SSA
考虑这段简单的代码:
x = input
if condition:
x = x + 1
else:
x = x * 2
return x
变量 x 被赋值了三次(初始化、if 分支、else 分支)。当编译器看到最后的 return x 时,它必须回答一个问题:这个 x 的值从哪里来? 取决于运行时的 condition,它可能来自 x + 1,也可能来自 x * 2。
在非 SSA 形式中,要回答这个问题需要进行复杂的数据流分析——沿着控制流回溯,考虑所有可能的赋值路径。这个分析的复杂度随程序规模增长而迅速增加。
SSA 通过一个优雅的机制解决了这个问题:给每次赋值一个唯一的版本号,并在控制流汇合点引入 (phi)节点 来合并不同分支的值。
转换后的 SSA 形式:
x₀ = input
if condition:
x₁ = x₀ + 1
else:
x₂ = x₀ * 2
x₃ = φ(x₁, x₂)
return x₃
现在,每个变量版本()只有一个定义点。当你看到 return x₃ 时,你可以直接找到 的唯一定义—— 节点。 节点告诉你: 的值取决于控制流来自哪个分支。
节点的本质
节点不是一个真实的计算操作。它是一个编译器内部的 选择器(selector):根据控制流的来源,选择对应分支的值。在最终的机器码中, 节点通常通过寄存器分配来消除——不同分支将结果放入同一个寄存器。
节点的形式化定义:
含义是:如果控制流从基本块 到达当前块,;如果从 到达,。
SSA 如何简化编译器优化
SSA 形式让许多经典优化变得简洁高效。以下是几个关键例子:
1. 死代码消除(Dead Code Elimination, DCE)
在 SSA 中,如果一个变量版本没有被任何其他操作使用(use list 为空),它就是”死”的,可以安全删除。不需要复杂的活性分析(liveness analysis)。
x₁ = a + b // x₁ 被 x₃ 使用 → 保留
x₂ = c * d // x₂ 没有使用者 → 删除!
x₃ = x₁ + 1 // x₃ 被 return 使用 → 保留
return x₃
2. 常量传播(Constant Propagation)
如果一个变量的唯一定义是常量赋值,那么所有使用点可以直接替换为该常量。在 SSA 中,“唯一定义”是自动保证的。
x₀ = 42 // 常量定义
y₀ = x₀ + 1 // 可以直接替换为 y₀ = 42 + 1 = 43
3. Use-Def Chain
SSA 最重要的属性是:每个变量的 use-def chain 是 trivial 的。给定一个变量的使用(use),它的定义(def)恰好只有一个。这使得很多数据流分析从 降到 。
在非 SSA 形式中,构建完整的 use-def chain 需要遍历所有可能的执行路径,是一个需要不动点迭代的过程。在 SSA 形式中,use-def chain 直接编码在 IR 的结构中。
4. 全局值编号(Global Value Numbering, GVN)
GVN 是发现和消除冗余计算的强大技术。在 SSA 形式中,两个操作如果有相同的操作码和相同的操作数(注意:同一个变量版本意味着同一个值),就一定计算相同的结果,可以合并为一个。
FX IR 详解
FX Graph 的起源与定位
torch.fx 是 PyTorch 2.0 引入的图级中间表示框架。它不是一个传统编译器意义上的”完整 IR”,而是一个轻量级、Python-native 的计算图表示,设计目标是:
- 让 PyTorch 用户和开发者容易理解和调试——IR 本身就是合法的 Python 代码
- 提供程序化的图变换能力——可以用 Python 代码操作图结构
- 作为 TorchDynamo 的输出格式——捕获的计算图以 FX Graph 形式存储
Graph 与 Node 结构
FX IR 的核心是 torch.fx.Graph 对象,它包含一组有序的 torch.fx.Node。每个 Node 代表计算图中的一个操作。一个典型的 FX Graph 如下:
graph():
%x : [#users=1] = placeholder[target=x]
%w : [#users=1] = placeholder[target=w]
%matmul : [#users=1] = call_function[target=torch.matmul](x, w)
%relu : [#users=1] = call_function[target=torch.relu](matmul)
return (relu,)
每个 Node 有以下关键属性:
| 属性 | 说明 | 示例 |
|---|---|---|
op | 操作类型(6 种之一) | call_function |
target | 目标函数/模块/属性 | torch.matmul |
args | 位置参数(其他 Node 或常量) | (x, w) |
kwargs | 关键字参数 | {} |
name | 节点名称(唯一标识) | matmul |
users | 使用此节点输出的节点集合 | {relu} |
Node 的六种操作类型
FX 定义了六种 Node 操作类型,覆盖了 Python 计算图的所有场景:
1. placeholder — 函数参数
表示图的输入参数。每个 placeholder 对应 traced 函数的一个参数。
%x : [#users=1] = placeholder[target=x]
2. call_function — 自由函数调用
表示对 Python 函数的调用,包括 torch.* 操作、operator.* 等。
%matmul : [#users=1] = call_function[target=torch.matmul](x, w)
3. call_method — 方法调用
表示对对象方法的调用,如 tensor.view()、tensor.permute() 等。
%view : [#users=1] = call_method[target=view](x, 128, 768)
4. call_module — 子模块调用
表示对 nn.Module 子模块的调用,target 是子模块在模型中的路径。
%linear : [#users=1] = call_module[target=layers.0.linear](x)
5. get_attr — 属性访问
表示访问模型的参数或 buffer。
%weight : [#users=1] = get_attr[target=layers.0.linear.weight]
6. output — 返回值
表示图的输出。
return (relu,)
FX Graph 的 SSA 属性
FX Graph 天然满足 SSA 约束。每个 Node 恰好定义一个值(其输出),且 Node 名称在图中唯一。这意味着 FX Graph 自动继承了 SSA 的所有优势:
- Use-def chain 是显式的:
node.users给出所有使用者,node.args给出所有依赖 - DCE 是 trivial 的:
len(node.users) == 0且不是output节点,就可以删除 - 遍历是线性的:Node 列表的拓扑序保证了依赖关系
FX Graph 的局限性
FX IR 的设计选择了”简单和 Python-native”,这带来了一些固有的限制:
1. 缺乏类型系统
FX Node 没有内置的类型信息(shape、dtype)。虽然可以通过 ShapeProp 等工具推导,但类型信息不是 IR 结构的一部分。这使得很多需要 shape 信息的优化变得更复杂。
2. 单一抽象层次
FX Graph 只有一层——所有操作都在同一个”函数调用”的抽象级别。不像 MLIR 可以在同一个 module 中混合不同抽象层次的操作。
3. 无控制流表示能力
在 TorchDynamo 的标准使用模式中,FX Graph 是控制流的一个”直线段”(graph break 发生在控制流边界)。图内部没有分支、循环等控制流结构。这意味着 FX Graph 不需要 节点——但也限制了它的表达能力。
4. Python 运行时依赖
FX Graph 本质上是 Python 对象,操作它需要 Python 运行时。这使得将 FX Graph 序列化、跨语言传递、或在非 Python 环境中使用变得困难。
MLIR 的设计哲学
IR 碎片化问题
在 MLIR 出现之前,编译器生态面临一个严峻的问题:IR 碎片化。
每个框架、每个硬件平台都有自己的 IR:TensorFlow 有 HLO,PyTorch 有 FX/TorchScript,TVM 有 Relay/TIR,XLA 有 StableHLO……这些 IR 之间互不兼容,导致了大量的重复工作:
- 同样的优化(如常量折叠、死代码消除)在每个 IR 上都要重新实现
- 支持新硬件需要为每个 IR 编写独立的后端
- 跨框架协作几乎不可能
更深层的问题在于:没有一个单一抽象层次的 IR 能够同时满足所有需求。高层 IR(如 HLO)适合做算子融合,但缺乏表达 tiling、向量化等低层变换的能力。低层 IR(如 LLVM IR)可以生成高效机器码,但丢失了张量级语义。
MLIR 的答案:可扩展的多层次 IR
MLIR(Multi-Level Intermediate Representation)由 Chris Lattner 在 Google 发起,2020 年发表论文 “MLIR: A Compiler Infrastructure for the End of Moore’s Law”。其核心创新是:
不定义一个固定的 IR,而是提供一个构建 IR 的框架。
MLIR 的关键设计选择包括:
1. Dialect(方言)系统
MLIR 将 IR 操作组织为 Dialect——一组相关的操作、类型和属性的集合。每个 Dialect 定义一个特定抽象层次的”语言”。不同的 Dialect 可以在同一个 module 中共存,实现渐进式下降(Progressive Lowering)。
例如,一个 module 可以同时包含 linalg.matmul(高层张量操作)和 arith.addf(标量算术),编译器逐步将高层操作 lower 到低层操作,而不是一步到位。
2. SSA + Region 的统一语义
MLIR 的所有操作(Operation)都遵循 SSA 形式:每个操作产生零个或多个 SSA 值(Value),每个值只有一个定义点。同时,操作可以包含 Region(区域),Region 包含 Block(基本块),Block 包含操作的有序序列——这样就能嵌套地表示控制流和数据流。
%result = "linalg.generic"(...) ({
^bb0(%arg0: f32, %arg1: f32):
%0 = arith.addf %arg0, %arg1 : f32
linalg.yield %0 : f32
}) : (tensor<...>, tensor<...>) -> tensor<...>
在这个例子中,linalg.generic 操作包含一个 Region,Region 中有一个 Block(^bb0),Block 中有具体的标量计算。
3. 操作(Operation)的统一接口
MLIR 中所有的 Operation 共享统一的内部表示:
%results = "dialect.op_name"(%operands) {attributes}
({regions}) : (operand_types) -> result_types
这种统一结构意味着:
- 通用的分析和变换工具(如 DCE、CSE)可以工作在任何 Dialect 上
- 新的 Dialect 可以复用现有的 pass 基础设施
- Dialect 之间的互操作是自然的
Dialect 层次结构
MLIR 生态中有数十个 Dialect,它们形成了一个从高层到低层的层次结构。对于 ML 编译而言,最重要的 Dialect 包括:
让我们详细看几个关键 Dialect:
Linalg Dialect:这是 ML 编译的核心 Dialect。它提供了两类操作:
- Named ops(命名操作):如
linalg.matmul、linalg.conv_2d,语义明确,易于 pattern matching - Generic op(通用操作):
linalg.generic可以表达任意元素级操作,通过indexing_maps和iterator_types描述访问模式
// linalg.matmul 展开为 linalg.generic 的等价形式
%result = linalg.generic {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>, // A 的访问模式
affine_map<(m, n, k) -> (k, n)>, // B 的访问模式
affine_map<(m, n, k) -> (m, n)> // C 的访问模式
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%A, %B : ...) outs(%C : ...) {
^bb0(%a: f32, %b: f32, %c: f32):
%prod = arith.mulf %a, %b : f32
%sum = arith.addf %c, %prod : f32
linalg.yield %sum : f32
} -> tensor<M x N x f32>
indexing_maps 使用仿射映射(affine map)描述每个操作数如何通过循环迭代器来索引,iterator_types 标记哪些维度是并行的、哪些是归约的。这些结构化信息为 tiling、fusion、parallelization 等变换提供了坚实的基础。
Tensor vs MemRef:MLIR 中”值语义”到”引用语义”的转换是通过 bufferization 实现的。
tensor<128x768xf32>— 值语义,不可变,如同数学中的张量。适合高层优化(fusion 分析不需要担心别名问题)memref<128x768xf32>— 引用语义,是对内存 buffer 的引用。必须考虑别名、生命周期、内存分配
这种设计让高层 pass 可以在”无副作用”的值语义世界中工作,大幅简化了分析;而低层 pass 则在引用语义世界中处理真实的内存布局。
SCF vs CF:类似地,控制流也有两层抽象:
scf.for、scf.if— 结构化控制流,保留了循环和分支的嵌套结构。编译器可以直接识别循环边界、进行循环变换(tiling、unrolling)cf.br、cf.cond_br— 非结构化控制流,退化为基本块之间的跳转(类似 LLVM IR 的 branch)。结构信息丢失,但表达能力更强
Dialect 间的下降(Lowering)
Dialect 之间的转换称为 lowering。MLIR 提供了一套 Conversion 框架 来系统化地实现 lowering:
linalg.matmul → scf.for + arith ops + memref ops → cf.br + llvm ops → LLVM IR
每一步 lowering 都是一个独立的 pass,可以单独测试和调试。这就是 Progressive Lowering(渐进式下降)的核心思想——不是一步从高层 IR 跳到机器码,而是逐步下降,每一步只做一个”小跳跃”。
渐进式下降的优势在于:
- 每一步变换都是局部和可验证的——不需要处理巨大的 abstraction gap
- 中间层次可以插入优化 pass——例如在 linalg 层做 fusion,在 memref 层做 buffer 复用,在 SCF 层做 loop tiling
- 新的 Dialect 可以插入到现有的 lowering 链中——扩展性极强
多层 IR 对比
现在,让我们用一个具体的例子来直观感受不同 IR 层次的差异。下图展示了一个高层算子如何在不同抽象层级被逐步展开为越来越多的低层操作:
我们将观察同一个简单操作——矩阵乘法加 ReLU——在五个不同 IR 层次上的表示。
1def fn(x, w):2 y = torch.matmul(x, w)3 return torch.relu(y)从上到下观察这些 IR,有几个关键趋势:
1. 代码量急剧增加
Python 源码只有 3 行,LLVM IR 需要十几行甚至更多(实际展开后是数百行)。每一步 lowering 都在”展开”高层操作的细节。
2. 抽象程度逐步降低
- Python:
torch.matmul— 一个函数调用 - FX:
call_function[target=torch.matmul]— 显式的图节点,但仍然是”调用 matmul” - Linalg:
linalg.matmul— 附带完整的类型信息(tensor<128x768xf32>),还能展开为linalg.generic看到 indexing_maps - MemRef:从值语义变成引用语义,出现了显式的内存管理
- LLVM IR:循环、标量操作、phi 节点——完全展开为底层操作
3. 类型信息从隐式变为显式
Python 中类型信息完全隐式。到了 MLIR,类型信息变为显式且丰富:tensor<128x768xf32> 不仅说明了形状,还说明了元素类型和语义(值语义 vs 引用语义)。到了 LLVM IR,类型退化为裸指针(float*)加手动偏移计算。
4. 控制流从隐式变为显式
Python 的 torch.matmul 隐式包含三层循环。在 Linalg 中,通过 iterator_types 描述迭代结构。到了 LLVM IR,循环变成了显式的 branch 和 phi 节点。
FX IR vs MLIR Dialect 对比
下表总结了两种 IR 体系的核心差异:
| 维度 | FX IR | MLIR |
|---|---|---|
| 设计目标 | PyTorch 生态内的图捕获和变换 | 通用编译器基础设施 |
| 实现语言 | Python | C++ (有 Python binding) |
| 类型系统 | 无内置类型(可通过 meta 信息推导) | 丰富的可扩展类型系统 |
| 抽象层次 | 单一层次(算子调用) | 多层次(通过 Dialect 组合) |
| SSA 形式 | 隐式满足(每个 Node 一个输出) | 显式 SSA(Value、Operation) |
| 控制流 | 不支持(graph break 处理) | 完整支持(SCF、CF Dialect) |
| 可扩展性 | 通过自定义 op 和 subgraph | 通过 Dialect 和 Operation 注册 |
| 序列化 | Python pickle / export | MLIR 文本/二进制格式(跨语言) |
| 优化基础设施 | 手写 Python 图变换 | Pass Manager、Pattern Rewriting |
| 适用场景 | PyTorch 模型优化、torch.compile | 跨框架、跨硬件的编译器工程 |
| 学习曲线 | 低(Python 开发者友好) | 高(需要编译器背景知识) |
这两种 IR 不是替代关系,而是互补关系。在 PyTorch 2.0 的编译流程中:
- TorchDynamo 捕获 FX Graph
- AOTAutograd 在 FX Graph 上做前向/反向联合追踪
- 后端编译器(如 Inductor 或未来的 MLIR-based 后端)将 FX Graph lower 到更低层的 IR 进行优化和代码生成
FX IR 擅长的是”在 Python 生态中快速原型和迭代”,MLIR 擅长的是”系统化地构建生产级编译器管线”。
总结
本文深入探讨了 ML 编译器中 IR 设计的基础。核心要点回顾:
SSA 形式 是现代编译器 IR 的基石。“每个变量只赋值一次”这条简单规则,加上 节点处理控制流汇合,使得 use-def chain、DCE、常量传播等关键分析和优化变得简洁高效。无论是 FX IR 还是 MLIR,都在本质上遵循 SSA 原则。
FX IR 是 PyTorch 2.0 的图级表示,以 Python-native 的方式提供了计算图的捕获、检查和变换能力。它的优势在于简单和易于调试,但单一抽象层次和缺乏内置类型系统限制了它作为”最终优化 IR”的能力。
MLIR Dialect 系统 提供了一个构建多层次 IR 的框架。通过 Dialect 将操作组织为不同的抽象层次,通过渐进式下降逐步将高层操作翻译为低层操作,通过统一的操作接口让通用分析工具跨 Dialect 工作。这种设计优雅地解决了 IR 碎片化问题。
在下一篇 IR 设计(下):Progressive Lowering 中,我们将深入探讨从 Linalg 到 LLVM IR 的完整下降过程,包括 bufferization、tiling 的 IR 变换、以及 Conversion 框架的工作原理。
延伸阅读
- Cytron et al., “Efficiently Computing Static Single Assignment Form and the Control Dependence Graph” (1991) — SSA 的奠基论文,详细描述了如何高效计算 SSA 形式和支配边界
- Lattner et al., “MLIR: A Compiler Infrastructure for the End of Moore’s Law” (2020) — MLIR 的设计论文,阐述了多层次 IR 和 Dialect 系统的动机与实现
- PyTorch
torch.fx文档 — FX Graph 的完整 API 参考,包括 symbolic tracing 和图变换的教程 - MLIR 官方文档 — 特别是 Language Reference 和 Dialects 页面,提供了所有 Dialect 的详细规范
- “Tutorial: Building a Compiler with MLIR” (MLIR 官方 Tutorial) — 从零搭建一个基于 MLIR 的编译器,是理解 MLIR 工程实践的最佳入口