分布式编译与图分割
更新于 2026-04-23
简介
单卡训练大模型的时代已经结束了。
LLaMA 70B 的参数量为 700 亿,以 FP16 存储需要约 140 GB 显存——而目前最强的单卡 NVIDIA H100 仅有 80 GB HBM3。即便是”较小”的 13B 模型(约 26 GB),加上优化器状态(Adam 需要额外 2 倍参数量)和激活值内存,单卡也难以承载训练所需的全部数据。更不用说 GPT-4 级别的万亿参数模型了。
这意味着分布式训练不是可选项,而是必需品。但分布式训练引入了三个编译器必须解决的核心问题:
- 怎么分割(Partitioning):将计算图和数据分割到多个设备上,使每个设备的内存占用可控
- 怎么通信(Communication):在分割后的设备间交换必要的数据(梯度、激活值、部分结果)
- 怎么藏延迟(Overlap):将通信延迟隐藏在计算之中,避免设备空闲等待
传统上,这三个问题由框架(PyTorch DDP、DeepSpeed)和用户手动解决——用户需要选择并行策略、插入通信原语、调整 micro-batch 大小。但现代编译器正在将这些决策自动化。XLA 的 SPMD partitioner、GSPMD、PyTorch 2.0 的 DTensor 抽象,都代表了编译器驱动的分布式策略这一趋势。
本文将从编译器的视角出发,系统地介绍分布式编译的核心技术:从并行策略的基础知识开始,到 GSPMD 的自动 sharding propagation 算法,再到 torch.compile 与分布式的集成,最后深入通信优化和图分割算法。
并行策略回顾
在深入编译器如何自动化分布式之前,我们需要理解编译器必须在哪些并行策略中做出选择。每种策略在内存、通信和计算效率上有不同的权衡。
数据并行 (Data Parallelism, DP)
数据并行是最简单也最常用的分布式策略。核心思想:每个设备持有完整的模型副本,但处理不同的数据 mini-batch。前向传播独立进行,反向传播后通过 AllReduce 同步梯度。
其中 是第 个设备上的本地梯度, 是设备数。
优点:实现简单,通信可与计算重叠(AllReduce 可以在反向传播过程中逐层启动),在模型能装入单卡时效率接近线性 scalability(~95-100%)。
限制:每个设备必须存放完整模型。当模型参数量超过单卡显存时,纯数据并行无法工作。
PyTorch 的 DistributedDataParallel(DDP)是数据并行的标准实现,通过 bucket AllReduce 实现梯度通信与反向传播的重叠。
全切分数据并行 (Fully Sharded Data Parallel, FSDP)
FSDP(原 ZeRO-3)是对数据并行的内存优化。核心思想:不仅切分数据,还将模型参数、梯度和优化器状态都均匀切分到所有设备上。每个设备只保存 的参数。
执行流程:
- 前向传播前,通过 AllGather 收集当前层的完整参数
- 计算该层的前向输出
- 立即释放非本地参数分片(仅保留本地 )
- 反向传播时再次 AllGather,计算梯度后 ReduceScatter 回各设备
FSDP 的内存节省是显著的:对于 Adam 优化器,每个参数需要 字节(FP16 参数 + FP32 主副本 + 动量 + 方差)。使用 FSDP,每卡仅需 字节/参数。这使得原本需要模型并行的场景可以用 FSDP 解决。
代价是更多的通信量:每层前向和反向各需要一次 AllGather,反向还需要 ReduceScatter。总通信量约为参数量的 3 倍(对比 DDP 的 2 倍),但可以通过预取(prefetching)和重叠来隐藏。
张量并行 (Tensor Parallelism, TP)
张量并行(由 Megatron-LM 提出)将单个层的计算在多个设备间水平切分。对于 Transformer 中的 MLP 层:
列并行(Column Parallel):将权重矩阵 按列切分为 ,每个设备计算 。由于 GeLU 等非线性激活是 element-wise 的,可以直接在分片上执行。
行并行(Row Parallel):将权重矩阵按行切分。每个设备有 (部分结果),需要一次 AllReduce 将 求和得到完整输出。
Megatron-LM 的经典 MLP 分割方案:
按列切分(每设备得到 的中间激活),GeLU 在分片上直接执行, 按行切分,最后 AllReduce 求和。每个 MLP 层需要 1 次 AllReduce(前向)+ 1 次 AllReduce(反向)。
对于 self-attention 层,多头注意力(Multi-Head Attention)天然适合张量并行:将不同的注意力头分配到不同设备。每个设备计算 个头,最后 AllReduce 输出投影。
优点:内存随设备数线性减少;计算效率高(~90-95%,特别是在 NVLink 高带宽互连下)。
限制:AllReduce 在每一层的前向和反向传播中都必须执行,因此对设备间带宽要求极高。NVLink Gen3(A100)提供约 600 GB/s 双向带宽,NVLink Gen4(H100)提供约 900 GB/s。而 PCIe Gen4 仅约 32 GB/s——这就是为什么张量并行通常限制在单个节点内(NVLink 连接的 8 卡之间),而不跨节点使用。
流水线并行 (Pipeline Parallelism, PP)
流水线并行将模型按层垂直切分为多个阶段(stages),每个阶段分配到一个设备。数据以 micro-batch 的形式在阶段间流动,类似工厂流水线。
GPipe 的方案:将一个 mini-batch 分成 个 micro-batch,按流水线方式执行。流水线效率为:
其中 是阶段数(设备数), 是 micro-batch 数。气泡(bubble)比例为 。例如 时,效率约 62.5%,气泡占 37.5%。
为减少气泡,1F1B(one forward, one backward)调度策略交替执行前向和反向,将气泡限制在启动和结束阶段。
优点:通信量最小(仅 P2P 传输层间激活值),对带宽要求低,适合跨节点(InfiniBand NDR ~50 GB/s = 400 Gb/s)。
限制:气泡导致计算效率损失;需要仔细的阶段划分以平衡各阶段的计算量。
专家并行 (Expert Parallelism, EP)
Mixture of Experts(MoE)模型中,不同的专家(expert)分布在不同设备上。输入通过路由器(router)分发到对应的专家,需要 All-to-All 通信进行数据重新分配。
EP 的编译器挑战在于:路由决策是动态的(依赖输入),因此通信模式是数据相关的,编译器很难在编译时完全确定通信量。
混合并行策略
真实的大规模训练系统几乎总是使用多种并行策略的组合。例如:
- LLaMA 70B (Meta): TP=8 (节点内 NVLink) + PP=4 (跨节点) + DP=16 (数据并行)
- GPT-3 175B (OpenAI/Microsoft): TP=8 + PP=8 + DP
- Megatron-Turing 530B (NVIDIA/Microsoft): TP=8 + PP=35 + DP=6
设计混合策略的核心原则:
- TP 使用在高带宽互连内(NVLink,节点内)
- PP 使用在中等带宽互连上(InfiniBand,跨节点)
- DP 使用在任何带宽条件下(通信可重叠)
上面的可视化展示了不同模型大小和 GPU 数量下,DP、TP、PP 三种策略在显存占用、通信模式和计算效率上的对比。注意 175B 模型在任何单策略下都无法在 1-2 张卡上运行——这就是混合并行的必要性。
GSPMD:编译器驱动的自动分割
手动设计并行策略需要深厚的专业知识,而且不同的模型架构可能需要不同的策略。GSPMD(General and Scalable Parallelization for ML Computation Graphs)提出了一种由编译器自动完成分割的方法。
Sharding Specification
GSPMD 的核心抽象是 sharding specification。每个张量都有一个 sharding spec,描述它在设备网格(device mesh)上的分布方式:
sharding_spec = {
tensor_dims: [batch, seq, hidden],
mesh_dims: [x, y],
mapping: {batch -> x, hidden -> y} // batch 沿 mesh 的 x 轴切分, hidden 沿 y 轴切分
}
例如,对于一个 的张量在 的设备网格上:
{batch -> x}表示 batch 维度切分到 4 个设备,每设备持有{batch -> x, hidden -> y}表示 batch 和 hidden 都切分,每设备持有{}表示全副本(replicated),每设备持有
这种表示的关键优势:它足够通用,可以统一表达 DP(batch 维度切分)、TP(hidden 维度切分)、以及它们的组合。
Sharding Propagation 算法
给定用户对少数张量的 sharding 标注,编译器需要推导出所有张量的 sharding spec,并在必要处插入通信算子。这就是 sharding propagation 的过程。
算法的核心是每种运算的 sharding 规则(sharding rule):
MatMul 规则:对于 ()
- 如果 沿 切分 → 也必须沿 切分 → 为 partial sum,需要 AllReduce
- 如果 沿 切分 → 也沿 切分( 无约束)
- 如果 沿 切分 → 也沿 切分( 无约束)
ElementWise 规则(如 ReLU、Add):
- 输入和输出的 sharding spec 必须一致——如果输入沿第 维切分,输出也沿第 维切分
Reduce 规则(如 Sum、Mean):
- 如果沿被 reduce 的维度切分 → 输出需要 AllReduce
- 如果沿非 reduce 维度切分 → 输出保持该切分
Propagation 以工作列表(worklist)算法实现:
- 初始化:将用户标注的张量加入已决定集合
- 遍历计算图中与已决定张量相连的算子
- 应用该算子的 sharding 规则,推导出未决定张量的 sharding spec
- 如果有新的张量被决定,加入工作列表
- 重复直到所有张量都已决定(或检测到冲突需要插入通信)
当两个输入要求同一张量有不同的 sharding spec 时,编译器需要插入 resharding 通信(如 AllGather 将 sharded 张量变为 replicated,AllReduce 将 partial sum 变为完整值,All-to-All 将一种切分变为另一种切分)。
点击 W1 标注为列并行:将 4D 维度切分到 N 个设备。其他张量显示 "?" 表示未决定。
上面的交互演示展示了 GSPMD 的 sharding propagation 过程。Step 1 中用户仅标注 W1 为列并行切分,Step 2 中编译器根据 MatMul 和 ElementWise 的 sharding 规则自动推导出所有中间张量的切分方式,Step 3 中编译器检测到 MatMul₂ 的输出为 partial sum(因为 W2 按行切分,矩阵乘法在切分维度上求和),自动插入 AllReduce 通信算子。
Cost Model
Propagation 可能有多种合法的 sharding 方案。GSPMD 使用 cost model 来选择最优方案:
其中 是 sharding 方案, 是通信与计算的权衡因子。通信 cost 考虑:
- AllReduce 的通信量:(ring AllReduce)
- AllGather 的通信量:
- 拓扑因子:NVLink 内 vs PCIe vs 跨节点的带宽差异
GSPMD 在 XLA 编译器中的实现(SPMD partitioner)已经成功应用于 PaLM(5400 亿参数)等超大规模模型的训练。
torch.compile 与分布式
PyTorch 2.0 的 torch.compile 正在逐步集成分布式能力,使编译器能够跨越分布式边界进行优化。
DTensor 抽象
PyTorch 的 DTensor(Distributed Tensor)是连接编译器和分布式的关键抽象。DTensor 在普通 Tensor 之上附加了两个元信息:
- DeviceMesh:描述设备的逻辑拓扑(如 的 2D mesh)
- Placement:描述张量在 mesh 上的分布方式(
Shard(dim),Replicate(),Partial())
# 创建一个 2D device mesh
mesh = DeviceMesh("cuda", [[0, 1, 2, 3], [4, 5, 6, 7]])
# 将张量分布到 mesh 上
# batch 维度沿 mesh dim 0 切分 (DP), hidden 维度沿 mesh dim 1 切分 (TP)
dtensor = distribute_tensor(tensor, mesh, [Shard(0), Shard(2)])
DTensor 的核心价值在于:它允许 torch.compile 将分布式通信视为计算图中的普通算子,从而进行跨设备的图优化。
FSDP + torch.compile 集成
PyTorch FSDP2 (FSDPv2) 使用 DTensor 作为底层表示,这使得 torch.compile 可以直接”看到”FSDP 的通信模式,并进行优化:
- AllGather 预取:编译器分析计算图的执行顺序,在需要某层参数之前提前发起 AllGather
- ReduceScatter 延迟:将 ReduceScatter 推迟到梯度真正被需要时才执行
- 通信-计算融合:将小的通信算子与计算算子融合,减少 kernel launch 开销
model = FSDP(model, use_orig_params=True)
model = torch.compile(model) # 编译器可以优化 FSDP 的通信模式
在 Meta 的实测中,torch.compile + FSDP2 相比纯 eager FSDP 可以获得 10-20% 的训练吞吐量提升,主要来自通信重叠的改善和不必要通信的消除。
TP + torch.compile 集成
张量并行与 torch.compile 的集成同样基于 DTensor。用户通过 parallelize_module API 指定 TP 策略:
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
tp_mesh = DeviceMesh("cuda", list(range(8)))
parallelize_plan = {
"layers.*.attention.wq": ColwiseParallel(),
"layers.*.attention.wk": ColwiseParallel(),
"layers.*.attention.wv": ColwiseParallel(),
"layers.*.attention.wo": RowwiseParallel(),
"layers.*.feed_forward.w1": ColwiseParallel(),
"layers.*.feed_forward.w2": RowwiseParallel(),
}
model = parallelize_module(model, tp_mesh, parallelize_plan)
model = torch.compile(model)
编译器在这个流程中的角色:
- 追踪 DTensor 的 placement 信息通过计算图
- 识别哪些通信是冗余的(例如连续两个 AllReduce 可以合并)
- 将通信算子调度到最优的时间点
编译器 vs 框架的分工
值得注意的是,PyTorch 当前的分布式编译仍处于”编译器辅助,框架主导”的阶段:
| 决策 | 框架(用户指定) | 编译器(自动优化) |
|---|---|---|
| 选择 DP/TP/PP | 用户决定 | 未来目标 |
| DTensor placement | 用户指定 | 编译器传播 |
| 通信算子插入 | DTensor 自动 | 编译器可消除冗余 |
| 通信调度 | 基本的规则 | 编译器全局优化 |
| 通信-计算重叠 | 手动或 FSDP 内置 | 编译器可进一步优化 |
对比 GSPMD 的”全自动”方案,PyTorch 选择了更渐进的路线:先让用户能表达分布式意图(通过 DTensor),再让编译器优化执行。这种设计尊重了 PyTorch 用户习惯于灵活控制的传统。
通信优化
无论采用哪种并行策略,分布式训练的性能瓶颈往往在通信上。编译器可以通过多种优化技术来降低通信的有效开销。
硬件拓扑感知
通信优化的第一步是理解硬件拓扑。不同互连的带宽差异巨大:
| 互连 | 带宽 | 典型场景 |
|---|---|---|
| NVLink Gen3 (A100) | ~600 GB/s (双向) | 节点内 GPU 间 |
| NVLink Gen4 (H100) | ~900 GB/s (双向) | 节点内 GPU 间 |
| NVSwitch (H100 DGX) | 全对全 900 GB/s | 8-GPU 全连接 |
| PCIe Gen4 x16 | ~32 GB/s (单向) | GPU-CPU, 部分 GPU 间 |
| PCIe Gen5 x16 | ~64 GB/s (单向) | 新一代 GPU-CPU |
| InfiniBand NDR | ~50 GB/s (单向, 400Gb/s) | 跨节点 |
| InfiniBand NDR400 | ~50 GB/s x8 lanes | 下一代跨节点 |
| RoCE (RDMA over Converged Ethernet) | ~25-50 GB/s | 以太网跨节点 |
编译器利用拓扑信息来做出关键决策:
- 通信原语选择:NVLink 上用 NCCL 的 ring/tree AllReduce;跨节点用 hierarchical AllReduce
- 分割策略约束:TP 仅限 NVLink 连接的设备;PP 优先用于跨节点
- 通信量 vs 通信次数权衡:高带宽低延迟的互连偏好小而频繁的通信;低带宽高延迟的互连偏好大而少的通信
AllReduce 融合 (Fusion)
多个小的 AllReduce 可以融合为一个大的 AllReduce。这减少了以下开销:
- Kernel launch 开销:每次 NCCL 调用有 ~10μs 的启动延迟
- 同步开销:每次 AllReduce 需要一次全局 barrier
- 带宽利用率:小消息无法充分利用互连带宽(带宽随消息大小增长,在 ~1MB 处开始饱和)
PyTorch DDP 默认使用 25 MB 的 bucket size 进行梯度 AllReduce 融合。编译器可以进一步优化:
- 自适应 bucket size:根据网络拓扑和当前通信负载动态调整
- 跨层融合:将不同层的梯度融合在同一个 AllReduce 中
- 算子融合:将 AllReduce + 后续的参数更新融合为一个 kernel
计算-通信重叠 (Compute-Communication Overlap)
重叠是隐藏通信延迟的最重要技术。核心思想:在一个 CUDA stream 上执行计算,在另一个 CUDA stream 上同时执行通信,利用 GPU 的计算单元和网络硬件的独立性。
DDP 中的重叠:反向传播中,一旦某些层的梯度计算完成,立即启动这些梯度的 AllReduce,同时继续计算更前面层的梯度。
FSDP 中的重叠:
- 前向传播:在计算当前层时,预取下一层的 AllGather
- 反向传播:在计算当前层梯度时,预取前一层的 AllGather + 异步执行后一层的 ReduceScatter
TP 中的重叠:更具挑战性,因为 AllReduce 在每层的计算路径上(不是梯度路径)。一种方法是将 AllReduce 分解为 ReduceScatter + AllGather,在 ReduceScatter 等待期间执行不依赖通信结果的计算。
重叠的效果取决于通信与计算的时间比。当通信时间小于计算时间时,通信可以完全被隐藏;当通信时间超过计算时间时,重叠只能部分隐藏延迟。
上面的交互演示展示了三种通信优化策略的效果。尝试调整层数和通信/计算比来观察不同配置下的性能差异。关键观察:
- 当通信/计算比为 10% 时,简单的重叠策略就能几乎完全隐藏通信
- 当通信/计算比达到 50% 时,需要 AllReduce 融合或 Bucket 策略来进一步优化
- Bucket AllReduce 通过将通信分散到整个反向传播过程中,实现了最均匀的重叠
CUDA Stream 并行的实现细节
要实现有效的计算-通信重叠,需要正确管理 CUDA stream 和事件:
compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
for layer in model.layers:
# 计算流:执行当前层前向传播
with torch.cuda.stream(compute_stream):
output = layer(input)
# 在通信流上等待计算完成
event = compute_stream.record_event()
comm_stream.wait_event(event)
# 通信流:异步 AllReduce
with torch.cuda.stream(comm_stream):
dist.all_reduce(output, async_op=True)
编译器(如 torch.compile)可以自动生成这种 stream 管理代码,而不需要用户手动插入。这正是编译器在分布式优化中的价值:将底层的 stream 调度、事件同步、内存管理自动化。
图分割算法
当编译器需要将计算图分配到多个设备时,面临的是一个图分割(graph partitioning)问题。这个问题在流水线并行中尤其关键:如何将 层的模型分割为 个阶段,使每个阶段的计算量尽可能均衡。
加权图分割
形式化地,给定计算图 :
- 节点 有权重 (计算量和内存占用)
- 边 有权重 (通信量)
- 目标:将 分为 个子集 ,使得:
- 负载均衡: 最小化
- 通信最小化: 最小化
这是一个 NP-hard 问题(即使 ),所以编译器使用近似算法。
PP 阶段划分算法
对于流水线并行,由于 Transformer 模型的层通常是线性序列,问题简化为序列分割:
贪心算法:将总计算量除以 ,贪心地将层分配到各阶段,使每个阶段的计算量接近 。时间复杂度 ,但不保证最优。
动态规划(DP): 表示将前 层分割为 个阶段时,最大阶段计算量的最小值。
时间复杂度 ,对于 的场景完全可行。
整数线性规划(ILP):对于更复杂的约束(如内存限制、异构设备),可以将问题建模为 ILP。虽然最坏情况是指数时间,但实际的 Transformer 模型结构足够规则,现代 ILP 求解器(如 Gurobi、CPLEX)可以在秒级时间内求解。
Alpa 的混合搜索
Alpa(Automated inter- and intra-operator parallelism)采用了一种两层搜索策略:
层内并行(Intra-op):对于每个 stage,使用 ILP 求解最优的 TP 切分方案(类似 GSPMD 的 sharding propagation,但在优化器中搜索)
层间并行(Inter-op):使用 DP 算法搜索最优的 PP 阶段划分,其中每个阶段的成本由层内 ILP 提供
这种分层搜索的关键优势:将指数级的搜索空间分解为两个可解的子问题。Alpa 在实验中展示了接近专家手动调优的性能,但完全自动化。
编译器中的实际实现
在 XLA 的 SPMD partitioner 中,图分割的流程大致如下:
- Profiling:首先在单设备上运行计算图的轻量 profiling,获取每个算子的计算时间和内存占用
- Cost 估计:基于 profiling 数据和通信模型(考虑设备拓扑),估计每种分割方案的总成本
- 搜索:在搜索空间中寻找成本最低的分割方案(使用 DP 或 ILP)
- Lowering:将分割方案转化为具体的 sharding spec,并通过 propagation 补全所有张量的 sharding
- Code generation:为每个设备生成 SPMD 代码(所有设备执行相同的代码,但操作不同的数据分片),并插入必要的通信算子
这个流程体现了编译器在分布式优化中的核心价值:将用户的高层意图(“在 64 张 GPU 上训练这个模型”)转化为高效的底层执行计划,自动处理分割、通信和调度的复杂性。
总结
分布式编译是 ML 编译器面临的最复杂挑战之一。本文涵盖了以下核心内容:
并行策略基础:DP(简单但内存受限)、TP(高效但需 NVLink)、PP(省内存但有气泡)、FSDP(内存优化的 DP),以及混合并行策略。
GSPMD 自动分割:通过 sharding specification 统一表达各种并行策略,sharding propagation 算法自动推导完整的分割方案,cost model 在多种方案中选择最优。
torch.compile + 分布式:DTensor 抽象让编译器能够感知和优化分布式通信,FSDP2/TP 与编译器的集成已经带来 10-20% 的吞吐量提升。
通信优化:硬件拓扑感知、AllReduce 融合、计算-通信重叠——编译器通过全局视角实现了比手工优化更系统化的通信调度。
图分割算法:从贪心到 DP 到 ILP,再到 Alpa 的混合搜索,编译器正在逐步自动化分布式训练中最需要专业知识的决策。
分布式编译的未来方向是”完全自动化”:用户只需指定模型和硬件配置,编译器自动搜索最优的并行策略、分割方案和通信调度。GSPMD 和 Alpa 已经向这个方向迈出了重要一步,而 torch.compile 的分布式集成正在将这些技术带入更广泛的 PyTorch 生态系统。
在下一篇文章中,我们将讨论调度与执行优化——当编译器完成了图的优化、融合、分割之后,如何高效地调度这些操作在 GPU 上的执行,包括 CUDA Stream 编排、CUDA Graph 捕获和内存规划。
学习路径:图编译与优化