本站内容由 AI 生成,可能存在错误。如发现问题,欢迎到 GitHub Issues 反馈。

图优化 Pass(上):数据流分析基础与通用 Pass 模式

图优化 Pass(上):数据流分析基础与通用 Pass 模式

更新于 2026-04-23

查看全景图用户代码全景图计算图捕获IR 设计优化 Pass5. 数据流分析 & Pass 基础你在这里算子融合代码生成调度与执行硬件执行

简介:从 IR 到优化器

在前面的文章中,我们深入探讨了中间表示(IR)的设计——SSA 形式、FX Graph 和 MLIR Dialect 系统。这些 IR 提供了表达程序语义的结构,但 IR 本身只是编译器的”数据格式”。真正让编译器发挥作用的,是在 IR 上进行的分析(Analysis)变换(Transformation)

这就引出了编译器优化的核心概念:Pass

一个 Pass 是编译器中的一个独立模块,它以 IR 为输入,进行分析或变换,并输出修改后的 IR(或分析结果)。Pass 是编译器优化的基本组织单元——它让复杂的编译流程变得模块化、可测试、可复用。

本文是图优化 Pass 系列的上篇。我们将从基础概念出发,深入讲解以下核心内容:

  1. 什么是 Pass:Pass 的类型、组织方式、设计原则
  2. 数据流分析基础:格理论(Lattice Theory)、传递函数(Transfer Function)、Worklist 算法
  3. 经典 Pass 深入讲解:死代码消除(DCE)、公共子表达式消除(CSE)、常量折叠(Constant Folding)
  4. Pass 管理基础设施:PyTorch FX 的图变换工具、MLIR Pass Manager
  5. 不动点迭代:如何保证数据流分析收敛

通过本文,你将理解编译器优化的系统化思维方式,掌握实现基础 Pass 的核心技术。

什么是 Pass

Pass 的本质

从最简单的视角看,一个 Pass 就是一个函数:

Pass:IRIR\text{Pass} : \text{IR} \to \text{IR}

或者更准确地说:

Pass:IRIR×Metadata\text{Pass} : \text{IR} \to \text{IR} \times \text{Metadata}

Pass 接收一个 IR 作为输入,对其进行分析或变换,输出修改后的 IR 以及可能的元数据(例如”是否做了修改”、“分析结果”等)。

这个简单的定义隐藏了编译器优化的核心思想:分而治之。复杂的优化流程被分解为一系列独立的 Pass,每个 Pass 专注于一个特定的优化或分析任务。这种模块化设计带来了多个优势:

  • 可测试性:每个 Pass 可以独立测试,不需要运行整个编译流程
  • 可复用性:通用 Pass(如 DCE、CSE)可以在不同编译器中复用
  • 可调试性:可以单独禁用某个 Pass,快速定位是哪个优化引入了问题
  • 可扩展性:添加新优化只需实现新的 Pass,不需要修改编译器框架

Analysis Pass vs Transform Pass

Pass 分为两大类:

Pass 分类体系PassAnalysis Pass(分析)只读,收集信息活跃性分析支配树Transform Pass(变换)修改 IR 结构Local Pass单基本块窥孔优化Global Pass整个函数DCE / CSEModule Pass整个模块函数内联

1. Analysis Pass(分析 Pass)

只读取 IR,不修改 IR。它的目标是提取信息,为后续 Pass 提供分析结果。

例如:

  • Dominator Tree Analysis:计算控制流图的支配树(Dominator Tree)
  • Alias Analysis:判断两个指针是否可能指向同一内存位置
  • Shape Inference:推导张量的形状

Analysis Pass 的输出通常是一个数据结构(如映射表、图结构),存储在编译器的分析缓存中。后续的 Transform Pass 可以查询这些分析结果。

2. Transform Pass(变换 Pass)

修改 IR,实现具体的优化。

例如:

  • Dead Code Elimination:删除不可达或无用的代码
  • Common Subexpression Elimination:消除重复计算
  • Loop Tiling:将循环拆分为更小的块,提升缓存局部性

Transform Pass 可能依赖 Analysis Pass 的结果。例如,DCE 需要知道哪些节点是”活跃”的,这个信息来自活性分析(Liveness Analysis)。

Local Pass vs Global Pass

Pass 还可以按作用范围分类:

1. Local Pass

只在单个基本块(Basic Block)或单个函数内部工作。不需要考虑跨块的控制流。

例如:

  • Constant Folding(常量折叠):如果看到 x = 2 + 3,直接替换为 x = 5
  • Algebraic Simplification(代数简化):如 x + 0 → xx * 1 → x

Local Pass 的优势是简单、快速。但它们的优化能力有限——很多优化机会需要跨基本块的信息。

2. Global Pass

需要分析整个函数或整个程序(Interprocedural Pass)。通常涉及控制流图(CFG)的遍历。

例如:

  • Global Value Numbering(全局值编号):发现跨基本块的重复计算
  • Loop Invariant Code Motion(循环不变量外提):将循环内的常量计算移到循环外
  • Inlining:将函数调用替换为函数体

Global Pass 更强大,但也更复杂。它们需要数据流分析(Data Flow Analysis)来跟踪信息如何在控制流图中传播。

PyTorch FX 中的 Pass 示例

在 PyTorch FX 中,Pass 通常是一个 Python 函数,接收 GraphModule 并返回修改后的 GraphModule

import torch
from torch.fx import GraphModule

def dead_code_elimination(gm: GraphModule) -> GraphModule:
    """
    删除没有被使用的节点(死代码)。
    """
    graph = gm.graph
    # 收集所有被使用的节点
    used_nodes = set()
    for node in graph.nodes:
        if node.op == 'output':
            # 从 output 节点开始,递归标记所有输入
            def mark_used(n):
                if n not in used_nodes:
                    used_nodes.add(n)
                    for inp in n.all_input_nodes:
                        mark_used(inp)
            for arg in node.args:
                if isinstance(arg, torch.fx.Node):
                    mark_used(arg)
    
    # 删除未使用的节点
    for node in list(graph.nodes):
        if node not in used_nodes and node.op not in ('placeholder', 'output'):
            graph.erase_node(node)
    
    gm.recompile()
    return gm

这个简单的 DCE Pass 展示了 FX 图变换的基本模式:遍历节点、判断条件、修改图结构。

数据流分析基础

许多全局优化需要知道”程序某一点的状态”——例如,某个变量在这里的值是多少?哪些变量在这里是活跃的?这就是**数据流分析(Data Flow Analysis)**要解决的问题。

数据流分析是编译器优化的核心技术之一。它提供了一套数学化的框架,让我们可以系统化地推导程序在每一点的状态。

格理论(Lattice Theory)基础

数据流分析的数学基础是格(Lattice)。一个格是一个偏序集(Partially Ordered Set),其中任意两个元素都有最小上界(Least Upper Bound)和最大下界(Greatest Lower Bound)。

对于编译器优化而言,格定义了”信息的精确程度”。我们用一个经典例子——常量传播(Constant Propagation)——来说明。

常量传播的格

对于每个变量,我们希望知道它在程序某一点的值。可能的情况有三种:

  1. ⊤(Top):我们还不知道这个变量的值(初始状态)
  2. 常量 c:我们确定这个变量的值是常量 c
  3. ⊥(Bottom):这个变量的值不是常量(可能在不同路径有不同值)

这三种情况形成一个格,偏序关系为:

c\top \sqsupseteq c \sqsupseteq \bot

含义是:⊤ 是”最不精确”的信息(什么都不知道),常量是”精确”的信息,⊥ 是”冲突”的信息(确定不是常量)。

Meet 操作(∧)

当两条控制流路径汇合时,我们需要合并它们携带的信息。这个合并操作称为 Meet,记作 \land

对于常量传播,Meet 的定义是:

x=xcc=c(相同常量)c1c2=(不同常量)x=\begin{aligned} \top \land x &= x \\ c \land c &= c \quad \text{(相同常量)} \\ c_1 \land c_2 &= \bot \quad \text{(不同常量)} \\ \bot \land x &= \bot \end{aligned}

直觉理解:如果两条路径给出了相同的常量,我们就知道变量是那个常量。如果两条路径给出了不同的常量,我们就知道变量不是常量。

传递函数(Transfer Function)

传递函数描述了”一个语句如何改变数据流信息”。

对于常量传播,传递函数 FsF_s 的定义取决于语句 ss 的类型:

1. 赋值语句 x = c(常量赋值)

Fx=c(σ)=σ[xc]F_{x=c}(\sigma) = \sigma[x \mapsto c]

含义:变量 xx 的值变为常量 cc,其他变量不变。

2. 赋值语句 x = y + z

Fx=y+z(σ)={σ[xc1+c2]if σ(y)=c1,σ(z)=c2σ[x]otherwiseF_{x=y+z}(\sigma) = \begin{cases} \sigma[x \mapsto c_1 + c_2] & \text{if } \sigma(y) = c_1, \sigma(z) = c_2 \\ \sigma[x \mapsto \bot] & \text{otherwise} \end{cases}

含义:如果 yyzz 都是常量,那么 xx 是它们的和;否则 xx 不是常量。

3. 分支语句 if condition

对于分支语句,两条分支的传递函数相同(都是恒等函数),但在汇合点需要应用 Meet 操作。

Worklist 算法

数据流分析通常使用 Worklist 算法(也称为 Iterative Data Flow Analysis)来计算不动点。

算法的核心思想是:

  1. 初始化所有基本块的数据流信息(通常是 ⊤)
  2. 将所有基本块加入 Worklist
  3. 从 Worklist 中取出一个基本块 BB
  4. BB 应用传递函数,计算 BB 的输出状态
  5. 如果 BB 的输出状态发生变化,将 BB 的所有后继块加入 Worklist
  6. 重复步骤 3-5,直到 Worklist 为空

这个算法保证收敛,因为:

  • 格的高度有限(对于常量传播,最多三层)
  • 传递函数是单调的(信息只会变得更精确,不会回退)

伪代码

Input: CFG with basic blocks B₁, B₂, ..., Bₙ
Output: IN[B] and OUT[B] for each basic block B

// 初始化
for each block B:
    IN[B] = ⊤  // 初始状态:无信息
    OUT[B] = ⊤

Worklist = {B₁, B₂, ..., Bₙ}

// 迭代
while Worklist is not empty:
    B = Worklist.pop()
    
    // 计算 B 的输入状态(来自前驱块)
    IN[B] = ⋀(OUT[P] for P in predecessors(B))
    
    // 应用传递函数
    old_OUT = OUT[B]
    OUT[B] = TransferFunction(B, IN[B])
    
    // 如果输出变化,将后继块加入 Worklist
    if OUT[B] ≠ old_OUT:
        for S in successors(B):
            Worklist.add(S)

前向 vs 后向分析

数据流分析有两个方向:

1. 前向分析(Forward Analysis)

信息沿着控制流从前往后传播。基本块的输入状态由前驱块的输出状态决定。

例如:

  • 常量传播:从定义点向使用点传播
  • Available Expressions(可用表达式分析):哪些表达式在这一点已经被计算过

2. 后向分析(Backward Analysis)

信息沿着控制流从后往前传播。基本块的输出状态由后继块的输入状态决定。

例如:

  • Liveness Analysis(活性分析):哪些变量在这一点之后还会被使用
  • Dead Code Elimination:哪些代码不影响程序输出

对于后向分析,Worklist 算法略有不同:我们从出口块开始,沿着控制流反向遍历。

数据流分析:Worklist 算法演示

步骤 0
数据流分析:Worklist 算法演示B0 (entry)x = 3y = 5x=⊤, y=⊤, z=⊤, w=⊤B1z = x + yx=⊤, y=⊤, z=⊤, w=⊤B2if z > 0x=⊤, y=⊤, z=⊤, w=⊤B3w = z × 2x=⊤, y=⊤, z=⊤, w=⊤B4w = z − 1x=⊤, y=⊤, z=⊤, w=⊤B5 (exit)return wx=⊤, y=⊤, z=⊤, w=⊤WorklistB0B1B2B3B4B5图例⊤(未知)常量⊥(冲突)

经典 Pass 深入讲解

现在我们深入三个经典的 Pass,看看它们的算法、实现和效果。

死代码消除(Dead Code Elimination, DCE)

问题定义

死代码(Dead Code)是指”不影响程序可观察行为”的代码。有两种类型的死代码:

  1. 不可达代码(Unreachable Code):控制流永远不会到达的代码
  2. 无用代码(Useless Code):计算了值但这个值永远不会被使用

DCE 的目标是安全地删除所有死代码。

算法

DCE 是一个后向数据流分析问题。我们从程序的”输出”(返回值、I/O 操作、外部可见的副作用)开始,反向追踪哪些计算对输出有贡献。

对于 SSA 形式的 IR,DCE 的算法极为简洁:

Input: SSA-form program
Output: Program with dead code removed

// Step 1: 标记阶段 (Mark phase)
worklist = {output node}
live_set = {}

while worklist is not empty:
    node = worklist.pop()
    if node not in live_set:
        live_set.add(node)
        for operand in node.inputs:
            worklist.add(operand)

// Step 2: 删除阶段 (Sweep phase)
for node in program:
    if node not in live_set:
        delete node

SSA 的优势

在 SSA 形式中,每个变量只有一个定义点,且 use-def chain 是显式的。这让 DCE 变得 trivial:只需要检查一个节点的 users 列表是否为空。如果为空且不是 output 节点,它就是死代码。

在非 SSA 形式中,DCE 需要进行复杂的活性分析——沿着控制流反向追踪每个变量的使用情况,这是一个需要不动点迭代的过程。

死代码消除 (DCE) 前后对比变换前变换后DCExyaddmuladdsuboutputxyaddaddoutput死节点(被删除)活跃节点(保留)

死代码消除(Dead Code Elimination)

1. 初始计算图

图中存在未被使用的节点(e, f, const3)

aba + bc × 22a − be × 33output活跃死代码

在 PyTorch FX 中实现 DCE

def dce_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    graph = gm.graph
    live_nodes = set()
    
    # 从 output 节点开始标记
    for node in graph.nodes:
        if node.op == 'output':
            worklist = list(node.all_input_nodes)
            while worklist:
                n = worklist.pop()
                if n not in live_nodes:
                    live_nodes.add(n)
                    worklist.extend(n.all_input_nodes)
    
    # 删除死节点
    for node in list(graph.nodes):
        if node not in live_nodes and node.op not in ('placeholder', 'output'):
            graph.erase_node(node)
    
    gm.recompile()
    return gm

效果示例

Before:

def forward(x, y):
    a = x + 1          # 被使用 → 保留
    b = x * 2          # 死代码 → 删除
    c = a + y          # 被使用 → 保留
    d = b - y          # 死代码 → 删除
    return c

After DCE:

def forward(x, y):
    a = x + 1
    c = a + y
    return c

公共子表达式消除(Common Subexpression Elimination, CSE)

问题定义

公共子表达式(Common Subexpression)是指”在程序多处计算相同表达式”的情况。CSE 的目标是识别并消除重复计算。

算法

CSE 的核心是计算每个操作的”哈希值”(也称为 Value Number)。如果两个操作有相同的哈希值,它们就计算相同的结果,可以合并为一个。

对于 SSA 形式的 IR,哈希值的计算非常简单:

hash(op,operands)=(op_type,operand_ids)\text{hash}(\text{op}, \text{operands}) = (\text{op\_type}, \text{operand\_ids})

两个操作的哈希值相同,当且仅当:

  1. 操作类型相同(如都是 add
  2. 操作数相同(在 SSA 中,相同的变量版本意味着相同的值)

算法伪代码

Input: SSA-form program
Output: Program with common subexpressions eliminated

hash_table = {}  // hash → canonical node

for node in topological_order(program):
    if node.op in ['input', 'const']:
        continue  // 输入和常量不参与 CSE
    
    hash_val = (node.op, tuple(node.inputs))
    
    if hash_val in hash_table:
        // 找到重复计算,替换所有使用者
        canonical_node = hash_table[hash_val]
        replace_all_uses(node, canonical_node)
        delete node
    else:
        hash_table[hash_val] = node

在 MLIR 中的实现

MLIR 的 Canonicalization Pass 包含了 CSE。MLIR 的 Operation 类提供了 isIdenticalTo() 方法,用于判断两个操作是否等价。

// MLIR CSE 的核心逻辑(简化版)
DenseMap<Operation*, Operation*> cseDominanceMap;

for (Operation &op : block) {
    // 查找是否有等价的操作已经存在
    for (auto &entry : cseDominanceMap) {
        Operation *existingOp = entry.second;
        if (op.isIdenticalTo(existingOp)) {
            // 替换所有使用
            op.replaceAllUsesWith(existingOp->getResults());
            op.erase();
            break;
        }
    }
    // 记录这个操作
    cseDominanceMap[&op] = &op;
}
公共子表达式消除 (CSE) 前后对比变换前变换后CSExyadd重复muladd重复addoutputxyadd合并muladdoutput重复子表达式合并后唯一节点

公共子表达式消除(Common Subexpression Elimination)

1. 初始图

存在重复计算:c1 和 c2 都是 a + b

aba + ba + brelureluoutput

效果示例

Before:

def forward(x, y):
    a = x + y          # 第一次计算 x + y
    b = x * 2
    c = x + y          # 重复计算 x + y
    d = a + c
    return d

After CSE:

def forward(x, y):
    a = x + y          # 只计算一次
    b = x * 2
    d = a + a          # c 被替换为 a
    return d

注意事项

CSE 需要小心处理以下情况:

  1. 副作用(Side Effects):如果操作有副作用(如 I/O、修改全局状态),即使看起来相同也不能合并
  2. 支配关系(Dominance):只能用一个操作替换它支配的其他操作。如果两个操作在不同的控制流分支,需要检查支配关系

常量折叠(Constant Folding)

问题定义

常量折叠是指”在编译时计算常量表达式”。如果一个操作的所有输入都是编译时常量,那么这个操作的结果也可以在编译时计算。

算法

常量折叠是一个前向数据流分析问题。我们从输入和常量定义开始,沿着数据流传播常量信息。

对于 SSA 形式的 IR,算法非常直接:

Input: SSA-form program
Output: Program with constants folded

for node in topological_order(program):
    if node.op == 'const':
        continue  // 已经是常量
    
    // 检查所有输入是否都是常量
    all_const = True
    const_inputs = []
    for inp in node.inputs:
        if inp.op != 'const':
            all_const = False
            break
        const_inputs.append(inp.value)
    
    if all_const:
        // 编译时计算
        result = evaluate(node.op, const_inputs)
        // 替换为常量
        const_node = create_const(result)
        replace_all_uses(node, const_node)
        delete node

在 PyTorch FX 中的实现

def constant_fold_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    graph = gm.graph
    
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in FOLDABLE_OPS:
            # 检查所有输入是否都是常量
            all_const = True
            const_values = []
            for arg in node.args:
                if isinstance(arg, torch.fx.Node) and arg.op == 'get_attr':
                    # 从模型中获取常量参数
                    const_values.append(getattr(gm, arg.target))
                elif isinstance(arg, (int, float)):
                    const_values.append(arg)
                else:
                    all_const = False
                    break
            
            if all_const:
                # 编译时执行
                result = node.target(*const_values)
                # 创建常量节点
                with graph.inserting_before(node):
                    const_node = graph.create_node(
                        'get_attr', 
                        f'_folded_const_{node.name}',
                        args=(), kwargs={}
                    )
                    setattr(gm, const_node.target, result)
                node.replace_all_uses_with(const_node)
                graph.erase_node(node)
    
    gm.recompile()
    return gm

效果示例

Before:

def forward(x):
    a = 2 + 3          # 编译时常量
    b = x * a
    c = 10 / 2         # 编译时常量
    d = b + c
    return d

After Constant Folding:

def forward(x):
    a = 5              # 折叠为常量
    b = x * 5          # a 被内联
    c = 5.0            # 折叠为常量
    d = b + 5.0
    return d

常量传播 vs 常量折叠

这两个术语经常被混用,但有细微差别:

  • 常量折叠(Constant Folding):计算常量表达式的值
  • 常量传播(Constant Propagation):将常量值传播到使用点

通常它们一起工作:常量传播识别哪些变量是常量,常量折叠计算常量表达式,然后再次常量传播……如此迭代直到不动点。

Pass 管线交互演示

现在,让我们通过一个交互式演示来直观感受 Pass 的效果,以及不同 Pass 的组合顺序如何影响最终结果。

Pass 管线模拟器拖拽调整 Pass 顺序,观察不同管线对计算图的影响xy01x + 0y × 1x + 0relureluoutput可用的 Pass:死代码消除公共子表达式消除常量折叠规范化已选管线:运行管线重置图

观察要点

  1. Pass 顺序很重要:先运行 Constant Folding 再运行 DCE,可能比反过来更有效。因为 Folding 可能产生新的死代码。

  2. 不动点迭代:有时需要多次运行相同的 Pass。例如:

    • CSE 可能暴露新的死代码(合并后某些节点变得无用)
    • DCE 可能暴露新的 CSE 机会(删除节点后发现新的重复)
  3. Canonicalize 的特殊地位:Canonicalization(规范化)不是单一优化,而是一组”局部模式重写”的集合。它通常放在管线的最前面或多次运行。

Pass 管理基础设施

实际的编译器需要一套完整的基础设施来管理 Pass 的执行。这包括 Pass 的注册、依赖管理、缓存、调试工具等。

Pass Manager 架构Pass 注册表Pass ManagerIR排序 · 依赖分析缓存 · 失效管理调试 · 日志PyTorch FXreplace_patternShapeProp自定义 Pass注册调度执行应用变换FX GraphMLIRCanonicalizeCSEInliner注册调度执行应用变换MLIR Module

PyTorch FX 的图变换工具

PyTorch FX 提供了一套轻量级的图变换工具。最重要的是 subgraph_rewriter——它允许通过模式匹配和替换来实现 Pass。

Subgraph Rewriter 示例

from torch.fx import subgraph_rewriter

def pattern(x):
    """要匹配的模式:x + 0"""
    return x + 0

def replacement(x):
    """替换为:x"""
    return x

# 在 GraphModule 上应用重写
subgraph_rewriter.replace_pattern(gm, pattern, replacement)

这个工具自动识别图中所有匹配 pattern 的子图,并替换为 replacement。这是一种声明式的 Pass 实现方式——你只需描述”what”(匹配什么、替换为什么),框架负责”how”(如何遍历、如何修改图)。

更复杂的例子:融合 ReLU 到 Conv

def fuse_conv_relu(gm: GraphModule) -> GraphModule:
    def pattern(x, weight):
        conv = torch.nn.functional.conv2d(x, weight)
        relu = torch.nn.functional.relu(conv)
        return relu
    
    def replacement(x, weight):
        # 使用融合的 conv_relu 算子
        return torch.ops.my_ops.conv_relu(x, weight)
    
    subgraph_rewriter.replace_pattern(gm, pattern, replacement)
    return gm

MLIR Pass Manager

MLIR 提供了一套完整的 Pass 管理基础设施,包括:

1. Pass 注册

每个 Pass 继承自 OperationPass<OpT>,并通过宏注册到系统中。

// 定义一个 Pass
struct MyCSEPass : public OperationPass<MyCSEPass, ModuleOp> {
    void runOnOperation() override {
        // Pass 的实现
        ModuleOp module = getOperation();
        // ... 对 module 进行 CSE
    }
};

// 注册 Pass
void registerMyCSEPass() {
    PassRegistration<MyCSEPass>("my-cse", "My CSE Pass");
}

2. Pass Manager

Pass Manager 负责执行 Pass 管线,管理分析缓存,处理 Pass 失败等。

PassManager pm(&context);

// 添加 Pass
pm.addPass(createCSEPass());
pm.addPass(createDCEPass());
pm.addPass(createCanonicalizerPass());

// 运行管线
if (failed(pm.run(module))) {
    llvm::errs() << "Pass failed\n";
}

3. 嵌套 Pass(Nested Pass)

MLIR 支持在不同层次运行 Pass。例如,可以在每个函数上运行一个 Pass,而不是在整个 Module 上。

pm.addNestedPass<FuncOp>(createCSEPass());  // 在每个函数上运行 CSE

4. Pass Pipeline 配置

MLIR 支持通过文本格式配置 Pass 管线,方便实验和调试。

mlir-opt input.mlir --pass-pipeline='
  builtin.module(
    func.func(cse, canonicalize),
    inline,
    func.func(cse, dce)
  )'

这个管线的含义是:

  1. 对每个函数运行 CSE 和 Canonicalize
  2. 运行 Inlining(函数内联)
  3. 再次对每个函数运行 CSE 和 DCE

Pass 依赖与分析保留

Pass 之间可能有依赖关系。例如,某个 Pass 需要支配树(Dominator Tree)分析的结果。MLIR 的 Pass 系统提供了声明和管理这些依赖的机制。

struct MyPass : public OperationPass<MyPass, FuncOp> {
    void getDependentDialects(DialectRegistry &registry) const override {
        // 声明依赖的 Dialect
        registry.insert<arith::ArithDialect, scf::SCFDialect>();
    }
    
    void runOnOperation() override {
        FuncOp func = getOperation();
        // 查询分析结果
        auto &dominatorTree = getAnalysis<DominatorTree>();
        // ...
    }
};

分析保留(Analysis Preservation)

Transform Pass 修改 IR 后,某些分析结果可能失效。Pass 系统需要知道哪些分析仍然有效。

void runOnOperation() override {
    // ... 修改 IR
    
    // 标记保留的分析
    markAnalysesPreserved<DominatorTree>();
    // 或标记所有分析失效
    markAllAnalysesPreserved();
}

不动点迭代与收敛保证

许多优化需要多次运行才能达到最佳效果。例如:

初始: a = x + 0; b = a * 1; return b
第一轮 Canonicalize: a = x; b = a * 1; return b  // 删除 + 0
第二轮 Canonicalize: a = x; b = a; return b      // 删除 * 1
第三轮 Copy Propagation: return x                 // 传播 a, b

这就是**不动点迭代(Fixed-Point Iteration)**的概念:反复运行优化,直到 IR 不再变化(达到不动点)。

不动点迭代收敛IR 状态空间随迭代单调收缩初始 IR初始 IRPass 1Pass 1Pass 2Pass 2Pass 3Pass 3不动点单调收缩:每次至少消除一个冗余收敛

为什么会收敛

不动点迭代的收敛依赖两个关键性质:

1. 单调性(Monotonicity)

每次优化都让 IR “变好”(例如节点数减少、计算复杂度降低),且这个”好坏”是可以排序的。

对于数据流分析,单调性表现为:信息在格上单调下降(从 ⊤ 向 ⊥)。

2. 有界性(Boundedness)

优化空间是有限的。例如:

  • 节点数不能无限减少(最少是输入和输出节点)
  • 格的高度有限(对于常量传播,最多三层)

这两个性质保证了迭代必然在有限步内停止。

实际的收敛策略

实际编译器通常不会”迭代到不动点”——因为可能需要太多次迭代。常见的策略包括:

1. 固定次数迭代

运行优化管线固定次数(如 2-3 次),这通常足以捕获大部分优化机会。

PassManager pm(&context);
for (int i = 0; i < 3; ++i) {
    pm.addPass(createCSEPass());
    pm.addPass(createDCEPass());
    pm.addPass(createCanonicalizerPass());
}

2. 检测变化

每次运行 Pass 时,记录是否做了修改。如果连续几轮都没有修改,就停止迭代。

bool changed = true;
int round = 0;
while (changed && round < MAX_ROUNDS) {
    changed = false;
    for (auto pass : passes) {
        if (pass.run(module)) {
            changed = true;
        }
    }
    round++;
}

3. 启发式顺序

精心设计 Pass 的顺序,让大部分优化在第一轮就生效。经验法则:

  • 先运行 Canonicalize(清理)
  • 再运行高层优化(如算子融合)
  • 最后运行低层优化(如 DCE、CSE)

Worklist 算法的收敛分析

对于 Worklist 算法,我们可以严格证明收敛。

定理:如果格的高度为 hh,基本块数为 nn,则 Worklist 算法的时间复杂度为 O(hne)O(h \cdot n \cdot e),其中 ee 是边数。

证明思路

  • 每个基本块的状态在格上最多下降 hh
  • 每次下降最多导致后继块加入 Worklist
  • 总共最多 hneh \cdot n \cdot e 次 Worklist 操作

对于常量传播,h=3h = 3(⊤, constant, ⊥),所以复杂度是 O(ne)O(n \cdot e),即线性于图的大小。

总结

本文系统介绍了编译器优化 Pass 的基础知识。核心要点回顾:

Pass 是编译器优化的基本组织单元。通过将复杂的优化流程分解为独立的 Pass,我们获得了模块化、可测试、可复用的编译器架构。Pass 分为 Analysis Pass(只读分析)和 Transform Pass(修改 IR),以及 Local Pass(单块)和 Global Pass(跨块)。

数据流分析提供了系统化的优化框架。通过格理论、传递函数和 Worklist 算法,我们可以数学化地推导程序在每一点的状态。前向分析和后向分析分别适用于不同类型的优化问题。

经典 Pass 展示了 SSA 的威力。死代码消除(DCE)、公共子表达式消除(CSE)、常量折叠在 SSA 形式的 IR 上实现极为简洁——use-def chain 是显式的,变量版本唯一性保证了分析的正确性。

Pass 管理基础设施让优化系统化。PyTorch FX 的 subgraph_rewriter 提供了声明式的图重写能力,MLIR Pass Manager 提供了完整的 Pass 注册、依赖管理、分析缓存机制。精心设计的 Pass 管线和不动点迭代策略是高效优化的关键。

在下一篇 图优化 Pass(下):算子融合与 Pattern Rewriting 中,我们将深入探讨更高级的优化技术——算子融合(Operator Fusion)的各种模式(横向融合、纵向融合、循环融合)、基于模式匹配的重写系统(Pattern Rewriting)、以及 MLIR 的 Declarative Rewrite Rules(DRR)框架。

延伸阅读

  • Kildall, “A Unified Approach to Global Program Optimization” (1973) — 数据流分析的奠基论文,首次系统化地描述了格理论和 Worklist 算法
  • Wegman & Zadeck, “Constant Propagation with Conditional Branches” (1991) — 稀疏常量传播算法,比传统 Worklist 更高效
  • Cytron et al., “Efficiently Computing Static Single Assignment Form” (1991) — SSA 的构造算法,是 DCE/CSE 高效实现的基础
  • MLIR Pass Infrastructure 文档 — 详细描述了 MLIR 的 Pass 系统设计和 API
  • PyTorch FX Graph Manipulation 文档 — FX 的图变换工具和 subgraph_rewriter 的使用方法
  • LLVM Programmer’s Manual: Writing an LLVM Pass — LLVM 的 Pass 编写指南,是学习 Pass 实现的经典教程