Content on this site is AI-generated and may contain errors. If you find issues, please report at GitHub Issues.

Distributed Compilation and Graph Partitioning


Updated 2026-04-13


Introduction

The era of single-GPU model training is over.

LLaMA 70B has 70 billion parameters, requiring approximately 140 GB of memory for its FP16 weights alone, while a single NVIDIA H100 provides only 80 GB of HBM3. Even a “smaller” 13B model (about 26 GB of FP16 weights) becomes infeasible on a single card once you add gradients, optimizer states (Adam keeps two moment estimates per parameter, which at least doubles parameter memory), and activation memory. Models at the trillion-parameter scale are entirely out of the question.

This means distributed training is not optional — it is a necessity. But distributed training introduces three core problems that compilers must address:

  1. How to partition: Split the computation graph and data across multiple devices so that each device’s memory usage is manageable
  2. How to communicate: Exchange necessary data (gradients, activations, partial results) between partitioned devices
  3. How to hide latency: Overlap communication latency with computation to avoid idle devices

Traditionally, these three problems were solved by frameworks (PyTorch DDP, DeepSpeed) and manual user intervention. Users had to choose parallelism strategies, insert communication primitives, and tune micro-batch sizes. Modern compilers are now automating these decisions. XLA’s SPMD partitioner (GSPMD) and PyTorch’s DTensor abstraction both reflect the trend toward compiler-driven distributed strategies.

This article takes the compiler’s perspective to systematically introduce the core technologies of distributed compilation: starting from the fundamentals of parallelism strategies, through GSPMD’s automatic sharding propagation algorithm, to torch.compile’s distributed integration, and finally diving into communication optimization and graph partitioning algorithms.

Parallelism Strategy Review

Before exploring how compilers automate distribution, we need to understand the parallelism strategies that compilers must choose among. Each strategy has different trade-offs in memory, communication, and compute efficiency.

Data Parallelism (DP)

Data parallelism is the simplest and most widely used distributed strategy. The core idea: each device holds a complete model replica but processes a different data mini-batch. Forward passes are independent; after backward passes, gradients are synchronized via AllReduce.

$$g_{\text{avg}} = \frac{1}{N} \sum_{i=1}^{N} g_i$$

where $g_i$ is the local gradient on device $i$ and $N$ is the device count.

Advantages: Simple to implement, communication can overlap with computation (AllReduce can start layer-by-layer during backward pass), and efficiency approaches linear scalability (~95-100%) when the model fits on a single card.

Limitations: Each device must store the full model. When the model’s parameter size exceeds single-GPU memory, pure data parallelism cannot work.

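PyTorch’s DistributedDataParallel (DDP) is the standard data parallelism implementation. It uses bucketed AllReduce: gradients are grouped into buckets, and each bucket’s AllReduce is launched as soon as all of its gradients are ready. This overlaps gradient communication with the rest of the backward computation.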
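The following pure-Python simulation makes the AllReduce step concrete. It implements ring AllReduce, one of the algorithms NCCL uses, with plain lists standing in for GPU buffers. This is a sketch under simplifying assumptions: there is no real communication, the vector length must be divisible by $N$, and the final division by $N$ happens locally to produce $g_{\text{avg}}$. The function name is hypothetical.

```python
def ring_allreduce_avg(grads):
    """Simulate ring AllReduce over N workers, then average locally.

    grads: N equal-length lists (one local gradient per worker); the length
    must be divisible by N. Returns every worker's final buffer.
    """
    n = len(grads)
    size = len(grads[0]) // n                 # elements per chunk
    bufs = [list(g) for g in grads]           # each worker's working buffer

    def span(c):
        return range(c * size, (c + 1) * size)

    def ring_step(chunk_of, combine):
        # Snapshot every worker's outgoing chunk first, as if all sends happen at once
        msgs = [(r, chunk_of(r), [bufs[r][i] for i in span(chunk_of(r))]) for r in range(n)]
        for r, c, data in msgs:
            dst = (r + 1) % n                 # each worker sends to its right neighbour
            for i, v in zip(span(c), data):
                bufs[dst][i] = combine(bufs[dst][i], v)

    # Phase 1 (reduce-scatter): afterwards worker r owns the fully summed chunk (r + 1) % n
    for s in range(n - 1):
        ring_step(lambda r: (r - s) % n, lambda old, new: old + new)
    # Phase 2 (all-gather): circulate the finished chunks until every worker has all of them
    for s in range(n - 1):
        ring_step(lambda r: (r + 1 - s) % n, lambda old, new: new)
    return [[v / n for v in b] for b in bufs]  # g_avg = (1/N) * sum_i g_i

print(ring_allreduce_avg([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]))
# [[3.0, 4.0, 5.0, 6.0], [3.0, 4.0, 5.0, 6.0]]
```

Each worker sends $2(N-1)$ chunks, each $1/N$ of the gradient, so it sends $2(N-1)/N$ of the gradient size in total. This per-device volume is independent of $N$ for large $N$, which is why ring AllReduce scales well for data parallelism.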

Fully Sharded Data Parallel (FSDP)

FSDP (PyTorch’s implementation of the ZeRO-3 idea from DeepSpeed) is a memory optimization of data parallelism. The core idea is to shard not only the data but also the model parameters, gradients, and optimizer states evenly across all devices. Each device stores only $1/N$ of the parameters.

Execution flow:

  1. Before forward pass, AllGather to collect the full parameters for the current layer
  2. Compute the forward output for that layer
  3. Immediately release non-local parameter shards (keep only the local $1/N$)
  4. During backward, AllGather again, then ReduceScatter gradients back to each device

FSDP’s memory savings are significant. With mixed-precision training and the Adam optimizer, each parameter requires $2 + 2 + 4 + 4 + 4 = 16$ bytes (FP16 params + FP16 gradients + FP32 master copy + momentum + variance). With FSDP, each card only needs $16/N$ bytes per parameter. As a result, many scenarios that previously required model parallelism can be handled with FSDP instead.
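The memory arithmetic can be checked in a few lines. This sketch assumes mixed-precision Adam at 16 bytes per parameter. It ignores activations, temporarily unsharded layers, and allocator fragmentation, so real per-GPU usage will be higher. The helper name is hypothetical.

```python
# Bytes per parameter for mixed-precision Adam (assumed breakdown)
BYTES_PER_PARAM = {
    "fp16_params": 2,
    "fp16_grads": 2,
    "fp32_master_params": 4,
    "adam_momentum": 4,
    "adam_variance": 4,
}

def model_state_gb(num_params, num_gpus=1, sharded=False):
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes), excluding activations."""
    total = num_params * sum(BYTES_PER_PARAM.values())
    return (total / num_gpus if sharded else total) / 1e9

for n in (1, 8, 64):
    ddp = model_state_gb(70e9, n)                  # DDP replicates all model state
    fsdp = model_state_gb(70e9, n, sharded=True)   # FSDP shards it N ways
    print(f"70B, {n:>2} GPUs: DDP {ddp:7.1f} GB/GPU, FSDP {fsdp:7.1f} GB/GPU")
```

Even with 8-way sharding, a 70B model needs about 140 GB of model state per GPU under these assumptions. That still exceeds an 80 GB card, so FSDP at this scale needs more GPUs or must be combined with TP/PP.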

The cost is more communication. Each layer requires one AllGather in the forward pass, another AllGather in the backward pass, and a ReduceScatter of its gradients in the backward pass. Total communication volume is approximately 3x the parameter size (compared to DDP’s 2x for AllReduce), but this can be hidden through prefetching and overlap.

Tensor Parallelism (TP)

Tensor parallelism (introduced by Megatron-LM) splits individual layer computations horizontally across multiple devices. For the MLP layers in Transformers:

Column Parallel: Split the weight matrix $W$ by columns into $[W_1, W_2, ..., W_N]$, with each device computing $Y_i = XW_i$. Since non-linear activations like GeLU are element-wise, they can be applied directly on the sharded data.

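Row Parallel: Split the weight matrix by rows into $[W_1; W_2; ...; W_N]$, and split the input by columns into $[X_1, X_2, ..., X_N]$ to match. Each device computes a partial result $Y_i = X_i W_i$ with the full output shape. One AllReduce sums the $Y_i$ into the complete output $Y = \sum_i Y_i$.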

Megatron-LM’s classic MLP partitioning scheme:

$$Y = \text{GeLU}(XW_1) W_2$$

$W_1$ is column-split (each device gets $[B, S, 4H/N]$ intermediate activations), GeLU is applied directly on the shards, $W_2$ is row-split, and a final AllReduce sums the results. Each MLP layer requires 1 AllReduce (forward) + 1 AllReduce (backward).
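The following pure-Python simulation shows why this partitioning is exact. Splitting $W_1$ by columns and $W_2$ by rows lets GeLU run on purely local shards, and a single element-wise sum (the AllReduce) recovers the unpartitioned output. This is a sketch with hypothetical helper names and no real devices. Rows of `x` are tokens (the $B \cdot S$ dimension flattened), and `H` is the hidden size.

```python
import math
import random

def matmul(a, b):
    """Naive matrix product of nested lists: a is [m][k], b is [k][n]."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def gelu(x):
    """Exact GeLU via the Gaussian CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def mlp_reference(x, w1, w2):
    """Unpartitioned Y = GeLU(X W1) W2."""
    h = [[gelu(v) for v in row] for row in matmul(x, w1)]
    return matmul(h, w2)

def mlp_tensor_parallel(x, w1, w2, n):
    """Simulate n devices: W1 split by columns, W2 split by rows, one AllReduce at the end."""
    f = len(w1[0]) // n                                   # FFN columns per device (4H/N)
    partials = []
    for d in range(n):
        w1_d = [row[d * f:(d + 1) * f] for row in w1]     # column shard of W1: [H][4H/N]
        w2_d = w2[d * f:(d + 1) * f]                      # row shard of W2: [4H/N][H]
        h_d = [[gelu(v) for v in row] for row in matmul(x, w1_d)]  # local GeLU, no communication
        partials.append(matmul(h_d, w2_d))                # partial sum with the full [tokens][H] shape
    # AllReduce: element-wise sum of the partial outputs from all devices
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]

def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
tokens, H, N = 3, 4, 4                                    # 4H = 16 FFN columns, 4 per device
x, w1, w2 = rand_matrix(tokens, H), rand_matrix(H, 4 * H), rand_matrix(4 * H, H)
ref = mlp_reference(x, w1, w2)
tp = mlp_tensor_parallel(x, w1, w2, N)
print(max(abs(a - b) for ra, rb in zip(ref, tp) for a, b in zip(ra, rb)))  # rounding-level difference only
```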

For self-attention layers, Multi-Head Attention is naturally suited for tensor parallelism: different attention heads are assigned to different devices. Each device computes $N_h/N$ heads, with a final AllReduce on the output projection.

Advantages: Memory decreases linearly with device count; high compute efficiency (~90-95%, especially with NVLink high-bandwidth interconnect).

Limitations: AllReduce must be executed in both forward and backward passes of every layer, so it demands extremely high inter-device bandwidth. NVLink Gen3 (A100) provides approximately 600 GB/s bidirectional bandwidth, NVLink Gen4 (H100) approximately 900 GB/s. PCIe Gen4 offers only about 32 GB/s — this is why tensor parallelism is typically limited to within a single node (among 8 GPUs connected via NVLink), not used across nodes.

Pipeline Parallelism (PP)

Pipeline parallelism vertically splits the model by layers into multiple stages, with each stage assigned to a device. Data flows between stages as micro-batches, similar to a factory assembly line.

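GPipe’s approach is to divide one mini-batch into $M$ micro-batches and execute them in pipeline fashion. Assume $N$ stages (devices) with equal per-stage compute time. One pass then takes $M + N - 1$ time steps, and each stage is busy for only $M$ of them, so pipeline efficiency is:

$$\eta = \frac{M}{M + N - 1}$$

The bubble (idle) fraction of wall-clock time is $(N-1)/(M+N-1)$. Equivalently, the bubble overhead relative to the ideal compute time is $(N-1)/M$. For example, with $N=4, M=8$, efficiency is about 72.7% and bubbles take about 27.3% of wall-clock time, which is a 37.5% overhead relative to ideal compute time.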
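One way to check the bubble arithmetic is to simulate a GPipe forward schedule directly. In this model, stage $s$ processes micro-batch $m$ at time step $s + m$. The sketch assumes every stage takes one time unit per micro-batch, ignores the backward pass and communication, and counts idle slots. Function names are hypothetical.

```python
def gpipe_forward_schedule(num_stages, num_micro):
    """grid[t][s] = micro-batch run by stage s at time step t, or None if the stage is idle."""
    steps = num_micro + num_stages - 1
    grid = [[None] * num_stages for _ in range(steps)]
    for s in range(num_stages):
        for m in range(num_micro):
            grid[s + m][s] = m      # stage s starts micro-batch m only after stage s-1 finishes it
    return grid

def efficiency(num_stages, num_micro):
    """Fraction of (time step, stage) slots that do useful work."""
    grid = gpipe_forward_schedule(num_stages, num_micro)
    busy = sum(cell is not None for row in grid for cell in row)
    return busy / (len(grid) * num_stages)

for t, row in enumerate(gpipe_forward_schedule(4, 8)):
    print(t, " ".join("." if m is None else str(m) for m in row))
print(f"N=4, M=8: efficiency {efficiency(4, 8):.1%}")
```

The printed grid shows the triangular bubbles at the start and end of the pass. With $N=8$ and $M=1$, the same function gives 12.5%, because a single micro-batch leaves 7 of 8 stages idle at every step.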

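The 1F1B (one forward, one backward) schedule alternates forward and backward passes in steady state, which confines bubbles to the warm-up and cool-down phases. For a given number of micro-batches, its bubble is the same size as GPipe’s. However, each stage holds activations for at most about $N$ in-flight micro-batches instead of all $M$. This sharply reduces activation memory and makes it practical to increase $M$, which shrinks the bubble fraction. Interleaved 1F1B, which assigns several non-contiguous stages to each device, reduces the bubble itself.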
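One way to see what 1F1B changes is to generate each stage’s operation order and count how many micro-batches’ activations are alive at once. The sketch below assumes the standard non-interleaved schedule: warm-up forwards, then steady-state one-forward-one-backward, then cool-down backwards. It models operation order only, not timing. Function names are hypothetical.

```python
def one_f_one_b_order(stage, num_stages, num_micro):
    """Per-stage op order for non-interleaved 1F1B ("F3" = forward of micro-batch 3)."""
    warmup = min(num_stages - stage - 1, num_micro)
    ops = [f"F{m}" for m in range(warmup)]                       # warm-up forwards
    for i in range(num_micro - warmup):                          # steady state: one F, then one B
        ops += [f"F{warmup + i}", f"B{i}"]
    ops += [f"B{m}" for m in range(num_micro - warmup, num_micro)]  # cool-down backwards
    return ops

def peak_in_flight(ops):
    """Max number of micro-batches whose activations are held (forward done, backward not yet)."""
    live = peak = 0
    for op in ops:
        live += 1 if op[0] == "F" else -1
        peak = max(peak, live)
    return peak

for s in range(4):
    ops = one_f_one_b_order(s, 4, 8)
    print(f"stage {s}: {' '.join(ops)}  (peak in-flight: {peak_in_flight(ops)})")
```

With $N=4$ and $M=8$, stage $s$ never holds more than $N - s$ micro-batches. An all-forwards-then-all-backwards (GPipe-style) order would hold all $M = 8$ on every stage.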

Advantages: Minimal communication volume (only point-to-point transfer of inter-layer activations), low bandwidth requirements, suitable for cross-node communication (InfiniBand NDR ~50 GB/s = 400 Gb/s).

Limitations: Bubbles cause compute efficiency loss; careful stage partitioning is needed to balance computation across stages.

Expert Parallelism (EP)

In Mixture of Experts (MoE) models, different experts are distributed across different devices. Input is dispatched to the corresponding expert through a router, requiring All-to-All communication for data redistribution.

The compiler challenge with EP is that routing decisions are dynamic (input-dependent), so communication patterns are data-dependent and compilers cannot fully determine communication volume at compile time.
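A toy top-1 router makes this data dependence visible. The number of tokens each device receives in the All-to-All is only known after the router has run on the actual batch. The sketch makes three simplifying assumptions: top-1 routing, experts placed on devices in contiguous blocks, and no capacity limit. The function name is hypothetical.

```python
import random
from collections import Counter

def top1_dispatch(router_logits, experts_per_device):
    """Count the tokens each device receives in the All-to-All under top-1 routing.

    Assumes device d hosts experts [d * experts_per_device, (d + 1) * experts_per_device).
    """
    counts = Counter()
    for logits in router_logits:
        expert = max(range(len(logits)), key=logits.__getitem__)  # argmax over experts
        counts[expert // experts_per_device] += 1                 # device that hosts this expert
    return dict(sorted(counts.items()))

random.seed(0)
num_experts, num_devices, num_tokens = 8, 4, 32
for step in range(3):
    logits = [[random.gauss(0.0, 1.0) for _ in range(num_experts)] for _ in range(num_tokens)]
    print(f"step {step}: tokens per device = {top1_dispatch(logits, num_experts // num_devices)}")
```

The per-device counts, and therefore the message sizes, change from step to step. This is one reason MoE systems commonly impose a fixed per-expert capacity, which keeps buffer shapes static at compile time.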

Hybrid Parallelism Strategies

Real large-scale training systems almost always use combinations of multiple parallelism strategies. For example (configurations as commonly reported; exact settings vary between training runs):

  • LLaMA 70B (Meta): TP=8 (intra-node NVLink) + PP=4 (inter-node) + DP=16 (data parallel)
  • GPT-3 175B (OpenAI/Microsoft): TP=8 + PP=8 + DP
  • Megatron-Turing 530B (NVIDIA/Microsoft): TP=8 + PP=35 + DP=6

Core principles for designing hybrid strategies:

  1. TP on high-bandwidth interconnects (NVLink, intra-node)
  2. PP on medium-bandwidth interconnects (InfiniBand, inter-node)
  3. DP on any bandwidth (communication can be overlapped)
[Interactive figure: Parallel strategy comparison for a 70B model on 8 GPUs (A100, 80 GB each). Data Parallel (8 full copies): 140 GB per GPU, out of memory; gradient AllReduce, overlappable with compute. Tensor Parallel (8192/8 = 1024 per GPU): 18 GB per GPU; intra-layer AllReduce, requires NVLink; ~92% efficiency. Pipeline Parallel (10 layers per stage): 18 GB per GPU; point-to-point communication; 88% bubble, ~13% efficiency. Real-world deployment: LLaMA 70B uses TP=8 (intra-node NVLink) + PP=4 (inter-node) + DP=16 (data parallel).]

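The visualization above shows how DP, TP, and PP compare in memory usage, communication patterns, and compute efficiency across different model sizes and GPU counts. The 175B model (about 350 GB of FP16 weights alone) cannot fit on 1-2 GPUs under any strategy. At larger GPU counts, each single strategy hits its own limit: DP runs out of memory, TP depends on NVLink bandwidth, and PP loses efficiency to bubbles. This is why real systems combine them.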

GSPMD: Compiler-Driven Automatic Partitioning

Manually designing parallelism strategies requires deep expertise, and different model architectures may need different strategies. GSPMD (General and Scalable Parallelization for ML Computation Graphs) proposes a method for the compiler to automatically handle partitioning.

Sharding Specification

GSPMD’s core abstraction is the sharding specification. Each tensor has a sharding spec describing how it is distributed across a device mesh:

sharding_spec = {
    "tensor_dims": ["batch", "seq", "hidden"],
    "mesh_dims": ["x", "y"],
    "mapping": {"batch": "x", "hidden": "y"},  # batch split along mesh x-axis, hidden along y-axis
}

For example, for a $[B, S, D]$ tensor on a $4 \times 2$ device mesh:

  • {batch -> x} means the batch dimension is split across 4 devices, each device holding $[B/4, S, D]$
  • {batch -> x, hidden -> y} means both batch and hidden are split, each device holding $[B/4, S, D/2]$
  • {} means fully replicated, each device holding $[B, S, D]$

The key advantage of this representation: it is general enough to unify DP (batch dimension splitting), TP (hidden dimension splitting), and their combinations.
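Given a spec, the per-device shard shape follows mechanically. The sketch below assumes a dictionary-based spec like the one above and evenly divisible dimensions. Unmapped dimensions are treated as replicated. The function name and the concrete sizes ($B=32$, $S=128$, $D=1024$) are illustrative.

```python
def local_shape(global_shape, dim_names, mesh_shape, mapping):
    """Per-device shard shape under a GSPMD-style sharding spec.

    mapping: tensor dim name -> mesh axis name; unmapped dims are replicated.
    Assumes every split dimension divides evenly by its mesh axis size.
    """
    shape = []
    for size, name in zip(global_shape, dim_names):
        axis = mapping.get(name)
        parts = mesh_shape[axis] if axis is not None else 1
        assert size % parts == 0, f"{name}={size} not divisible by mesh axis {axis}={parts}"
        shape.append(size // parts)
    return shape

mesh = {"x": 4, "y": 2}                      # the 4 x 2 device mesh
dims = ["batch", "seq", "hidden"]
B, S, D = 32, 128, 1024
print(local_shape([B, S, D], dims, mesh, {"batch": "x"}))                 # [8, 128, 1024]
print(local_shape([B, S, D], dims, mesh, {"batch": "x", "hidden": "y"}))  # [8, 128, 512]
print(local_shape([B, S, D], dims, mesh, {}))                             # [32, 128, 1024]
```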

Sharding Propagation Algorithm

Given user sharding annotations on a few tensors, the compiler needs to infer the sharding spec for all tensors and insert communication operators where necessary. This is the sharding propagation process.

The algorithm’s core is the sharding rule for each operation:

MatMul rule: For $C = A \times B$ ($A: [M, K], B: [K, N]$)

  • If $A$ is split along $K$, then $B$ must also be split along $K$, and $C$ is a partial sum requiring AllReduce
  • If $A$ is split along $M$, then $C$ is also split along $M$ ($B$ unconstrained)
  • If $B$ is split along $N$, then $C$ is also split along $N$ ($A$ unconstrained)

ElementWise rule (e.g., ReLU, Add):

  • Input and output sharding specs must be identical — if the input is split along dimension ii, the output is also split along dimension ii

Reduce rule (e.g., Sum, Mean):

  • If split along the reduced dimension, the output requires AllReduce
  • If split along a non-reduced dimension, the output preserves that splitting
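The three rules can be written as small functions over per-dimension specs. Here a spec is a tuple with one mesh-axis name, or None, per tensor dimension. This is a simplified sketch with hypothetical names, not XLA’s implementation: each tensor dimension maps to at most one mesh axis, and illegal combinations simply raise instead of triggering resharding.

```python
def matmul_rule(a_spec, b_spec):
    """Sharding rule for C = A @ B with A: [M, K] and B: [K, N].

    Returns (c_spec, partial_axis). A non-None partial_axis means C holds partial
    sums that must be AllReduced over that mesh axis.
    """
    (a_m, a_k), (b_k, b_n) = a_spec, b_spec
    if a_k != b_k:
        raise ValueError(f"K is split inconsistently ({a_k} vs {b_k}); reshard first")
    if a_m is not None and a_m == b_n:
        raise ValueError("M and N cannot be split along the same mesh axis")
    return (a_m, b_n), a_k            # M and N splits pass through; a split K yields partial sums

def elementwise_rule(*input_specs):
    """Element-wise ops: every input must share one spec, which the output inherits."""
    if len(set(input_specs)) != 1:
        raise ValueError(f"inputs disagree {input_specs}; reshard before the op")
    return input_specs[0]

def reduce_rule(spec, dim):
    """Reduce over `dim`: other splits are preserved; a split on `dim` leaves partial results."""
    return spec[:dim] + spec[dim + 1:], spec[dim]

print(matmul_rule((None, "y"), ("y", None)))  # ((None, None), 'y'): partial sum, AllReduce over y
print(matmul_rule(("x", None), (None, "y")))  # (('x', 'y'), None): no communication
print(reduce_rule(("x", "y"), 1))             # (('x',), 'y'): AllReduce over y
```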

Propagation is implemented as a worklist algorithm:

  1. Initialize: add user-annotated tensors to the decided set
  2. Traverse operators connected to decided tensors in the computation graph
  3. Apply the operator’s sharding rule to infer undecided tensors’ sharding specs
  4. If new tensors are decided, add them to the worklist
  5. Repeat until all tensors are decided (or a conflict is detected requiring communication insertion)

When the sharding rules of different operators demand conflicting specs for the same tensor, the compiler inserts resharding communication. Examples include AllGather to convert sharded to replicated, AllReduce to convert a partial sum to a full value, and All-to-All to move the split from one dimension to another.
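The worklist loop and the resharding step can be sketched for a two-layer MLP, $Y = \text{ReLU}(X W_1) W_2$, with tokens flattened into rows. The sketch rests on these assumptions: all tensors are 2D, specs are `(row_axis, col_axis)` tuples, and only MatMul and element-wise rules are modeled. When a matmul operand’s spec is unknown, it defaults to matching the other operand’s contracting-dimension split, with its other dimension replicated. The names `infer`, `propagate`, and the mesh axis `"t"` are hypothetical.

```python
from collections import deque

def infer(op, specs):
    """Apply one operator's sharding rule. Returns (newly decided specs, comm ops to insert)."""
    kind, ins, out = op
    new, comms = {}, []
    if kind == "elementwise":
        known = [specs[t] for t in ins + [out] if t in specs]
        if known:
            for t in ins + [out]:
                if t not in specs:
                    new[t] = known[0]                    # output inherits the input spec (and vice versa)
                elif specs[t] != known[0]:               # conflict: a resharding op is needed
                    comms.append(("Reshard", t, specs[t], known[0]))
    elif kind == "matmul":                               # ins = [A, B] with A: [M, K], B: [K, N]
        a, b = ins
        sa, sb = specs.get(a), specs.get(b)
        if sa is not None and sb is None:
            sb = new[b] = (sa[1], None)                  # B's K follows A's K; N defaults to replicated
        elif sb is not None and sa is None:
            sa = new[a] = (None, sb[0])                  # A's K follows B's K; M defaults to replicated
        if sa is not None and sb is not None and out not in specs:
            new[out] = (sa[0], sb[1])                    # M and N splits pass through to C
            if sa[1] is not None:                        # split contracting dim: C is a partial sum
                comms.append(("AllReduce", out, sa[1]))
    return new, comms

def propagate(ops, annotations):
    """Worklist sharding propagation over a list of (kind, inputs, output) operators."""
    specs = dict(annotations)                            # step 1: annotated tensors are decided
    users = {}
    for op in ops:
        for t in op[1] + [op[2]]:
            users.setdefault(t, []).append(op)
    work, comms = deque(specs), []
    while work:
        for op in users.get(work.popleft(), []):        # step 2: ops touching a decided tensor
            new, found = infer(op, specs)               # step 3: apply the op's sharding rule
            specs.update(new)
            work.extend(new)                            # step 4: newly decided tensors join the worklist
            comms += [c for c in found if c not in comms]
    return specs, comms                                 # step 5: done when the worklist is empty

mlp = [
    ("matmul", ["X", "W1"], "H"),        # H = X @ W1   [tokens, 4D]
    ("elementwise", ["H"], "A"),         # A = ReLU(H)
    ("matmul", ["A", "W2"], "Y"),        # Y = A @ W2   [tokens, D]
]
specs, comms = propagate(mlp, {"X": (None, None), "W1": (None, "t")})
for name in ["H", "A", "W2", "Y"]:
    print(name, specs[name])
print(comms)   # [('AllReduce', 'Y', 't')]
```

Annotating only $W_1$ as column-split is enough for the loop to infer that $W_2$ should be row-split. The loop then discovers that $Y$ is a partial sum and records a single AllReduce.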

[Interactive figure: GSPMD sharding propagation on a two-layer MLP across N devices (GPU 0-3 shown). Step 1, user annotation: X [B, S, D] is replicated and W1 [D, 4D] is annotated column-parallel (each device holds [D, 4D/4]). MatMul₁ [B, S, 4D], ReLU [B, S, 4D], W2 [4D, D], MatMul₂ [B, S, D], and Y [B, S, D] are still undecided ("?"). Legend: user annotated, auto propagated, unresolved, communication op.]

The interactive demonstration above shows GSPMD’s sharding propagation process. In Step 1, the user only annotates W1 as column-parallel split. In Step 2, the compiler automatically infers all intermediate tensors’ splitting using MatMul and ElementWise sharding rules. In Step 3, the compiler detects that MatMul2’s output is a partial sum (because W2 is row-split and the matrix multiplication sums over the split dimension), and automatically inserts an AllReduce communication operator.

Cost Model

Propagation may yield multiple legal sharding schemes. GSPMD uses a cost model to select the optimal one:

$$\text{Cost}(s) = \sum_{op \in G} \text{compute\_cost}(op, s_{op}) + \alpha \sum_{e \in E} \text{comm\_cost}(e, s_e)$$

where $s$ is the sharding scheme and $\alpha$ is the communication-computation trade-off factor. Communication cost considers:

  • AllReduce communication volume per device: $2 \cdot \frac{N-1}{N} \cdot \text{tensor\_size}$ (ring AllReduce)
  • AllGather communication volume per device: $\frac{N-1}{N} \cdot \text{tensor\_size}$
  • Topology factors: NVLink intra-node vs PCIe vs inter-node bandwidth differences
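The cost model can be sketched with the two volume formulas and a bandwidth-only time model, where α weights communication against compute. All numbers below are illustrative assumptions, not measurements: the compute times, the 64 MB activation, ~450 GB/s per direction for NVLink, and ~50 GB/s for InfiniBand NDR. The sketch also ignores per-message latency and memory limits. Function names are hypothetical.

```python
def allreduce_bytes(tensor_bytes, n):
    """Ring AllReduce: bytes each device sends (reduce-scatter + all-gather)."""
    return 2 * (n - 1) / n * tensor_bytes

def allgather_bytes(tensor_bytes, n):
    """AllGather: bytes each device receives to rebuild the full tensor."""
    return (n - 1) / n * tensor_bytes

def scheme_cost(compute_s, comm_bytes, bw_gb_per_s, alpha=1.0):
    """Cost(s) = compute + alpha * comm, with a bandwidth-only comm model (no latency term)."""
    return compute_s + alpha * comm_bytes / (bw_gb_per_s * 1e9)

# Hypothetical layer: 8 ms of compute unsharded (1 ms per device under TP=8),
# plus a 64 MB activation to AllReduce when tensor-parallel
ar = allreduce_bytes(64e6, 8)
candidates = {
    "TP=8 over NVLink (~450 GB/s per direction)": scheme_cost(1e-3, ar, 450),
    "TP=8 over InfiniBand NDR (~50 GB/s)": scheme_cost(1e-3, ar, 50),
    "replicated (no sharding, no comm)": scheme_cost(8e-3, 0, 450),
}
for name, cost in candidates.items():
    print(f"{name}: {cost * 1e3:.2f} ms")
print("chosen:", min(candidates, key=candidates.get))
```

The same sharding scheme changes rank depending on which links it lands on. This is the sense in which topology factors enter the cost model.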

GSPMD’s implementation in the XLA compiler (SPMD partitioner) has been successfully applied to train ultra-large-scale models like PaLM (540 billion parameters).

torch.compile and Distributed Training

PyTorch 2.0’s torch.compile is progressively integrating distributed capabilities, enabling compilers to optimize across distributed boundaries.

DTensor Abstraction

PyTorch’s DTensor (Distributed Tensor) is the key abstraction connecting compilers and distributed training. DTensor augments regular Tensors with two pieces of metadata:

  • DeviceMesh: Describes the logical topology of devices (e.g., a $4 \times 2$ 2D mesh)
  • Placement: Describes how the tensor is distributed on the mesh (Shard(dim), Replicate(), Partial())
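import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard  # older releases: torch.distributed._tensor

# Launch with torchrun --nproc_per_node=8; build a 4 x 2 mesh (dim 0: DP, dim 1: TP)
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

# Distribute a [batch, seq, hidden] tensor on the mesh:
# batch dim split along mesh dim 0 (DP), hidden dim split along mesh dim 1 (TP)
torch.manual_seed(0)  # same global tensor on every rank
tensor = torch.randn(8, 128, 1024)
dtensor = distribute_tensor(tensor, mesh, [Shard(0), Shard(2)])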

DTensor’s core value: it allows torch.compile to treat distributed communication as regular operators in the computation graph, enabling cross-device graph optimizations.

FSDP + torch.compile Integration

PyTorch FSDP2 (FSDPv2) uses DTensor as its underlying representation, enabling torch.compile to directly “see” FSDP’s communication patterns and optimize them:

  1. AllGather prefetching: The compiler analyzes the computation graph’s execution order and initiates AllGather ahead of when a layer’s parameters are needed
  2. ReduceScatter deferral: Postpone ReduceScatter until gradients are actually needed
  3. Communication-computation fusion: Fuse small communication operators with computation operators to reduce kernel launch overhead
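import torch
from torch.distributed.fsdp import fully_shard  # FSDP2; earlier releases: torch.distributed._composable.fsdp
fully_shard(model)            # parameters become DTensors sharded across ranks (per-block wrapping omitted)
model = torch.compile(model)  # compiler can now see and optimize FSDP communication patterns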

In Meta’s benchmarks, torch.compile + FSDP2 achieves 10-20% training throughput improvement over pure eager FSDP, primarily from improved communication overlap and elimination of unnecessary communication.

TP + torch.compile Integration

Tensor parallelism integration with torch.compile is also based on DTensor. Users specify TP strategy via the parallelize_module API:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

tp_mesh = init_device_mesh("cuda", (8,))  # 8 GPUs within one NVLink node
# Plan keys are submodule names relative to each transformer block of a LLaMA-style
# model (model.layers); per-rank attention head counts must also be divided by 8 (omitted)
block_plan = {
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
}
for block in model.layers:
    parallelize_module(block, tp_mesh, block_plan)
model = torch.compile(model)

The compiler’s role in this flow:

  1. Track DTensor placement information through the computation graph
  2. Identify redundant communications (e.g., two consecutive AllReduces can be merged)
  3. Schedule communication operators at optimal time points

Compiler vs Framework Division of Labor

It is worth noting that PyTorch’s current distributed compilation is still in a “compiler-assisted, framework-driven” stage:

| Decision | Framework (user-specified) | Compiler (auto-optimized) |
|---|---|---|
| Choose DP/TP/PP | User decides | Future goal |
| DTensor placement | User specifies | Compiler propagates |
| Comm op insertion | DTensor automatic | Compiler removes redundancy |
| Comm scheduling | Basic rules | Compiler global optimization |
| Comm-compute overlap | Manual or FSDP built-in | Compiler further optimizes |

Compared to GSPMD’s “fully automatic” approach, PyTorch has chosen a more incremental path: first let users express distributed intent (via DTensor), then let the compiler optimize execution. This design respects PyTorch users’ tradition of preferring flexible control.

Communication Optimization

Regardless of the parallelism strategy chosen, the performance bottleneck in distributed training is often communication. Compilers can apply multiple optimization techniques to reduce the effective communication overhead.

Hardware Topology Awareness

The first step in communication optimization is understanding the hardware topology. Bandwidth differences across interconnects are substantial:

| Interconnect | Bandwidth | Typical Use Case |
|---|---|---|
| NVLink Gen3 (A100) | ~600 GB/s (bidirectional) | Intra-node GPU-to-GPU |
| NVLink Gen4 (H100) | ~900 GB/s (bidirectional) | Intra-node GPU-to-GPU |
| NVSwitch (H100 DGX) | Full bisection, 900 GB/s per GPU | 8-GPU all-to-all |
| PCIe Gen4 x16 | ~32 GB/s (unidirectional) | GPU-CPU, some GPU-to-GPU |
| PCIe Gen5 x16 | ~64 GB/s (unidirectional) | Next-gen GPU-CPU |
| InfiniBand NDR | ~50 GB/s per NIC (unidirectional, 400 Gb/s) | Inter-node |
| InfiniBand NDR, 8 NICs per node (e.g., DGX H100) | ~400 GB/s aggregate per node (8 × 50 GB/s) | Inter-node |
| RoCE (RDMA over Converged Ethernet) | ~25-50 GB/s | Ethernet inter-node |

Compilers use topology information to make critical decisions:

  • Communication primitive selection: Use NCCL’s ring/tree AllReduce on NVLink; use hierarchical AllReduce for inter-node
  • Partitioning strategy constraints: TP only for NVLink-connected devices; PP preferred for inter-node
  • Communication volume vs frequency trade-off: High-bandwidth low-latency interconnects favor small, frequent communications; low-bandwidth high-latency interconnects favor large, infrequent communications

AllReduce Fusion

Multiple small AllReduces can be fused into one large AllReduce. This reduces the following overheads:

  1. Kernel launch overhead: Each NCCL call incurs ~10 microseconds of launch latency
  2. Synchronization overhead: Each AllReduce requires a global barrier
  3. Bandwidth utilization: Small messages cannot saturate interconnect bandwidth. Achieved bandwidth grows with message size, because the fixed per-call latency dominates small transfers.

PyTorch DDP defaults to a 25 MB bucket size for gradient AllReduce fusion. Compilers can further optimize:

  • Adaptive bucket size: Dynamically adjust based on network topology and current communication load
  • Cross-layer fusion: Fuse gradients from different layers into the same AllReduce
  • Operator fusion: Fuse AllReduce + subsequent parameter update into one kernel
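Bucketing and its payoff can be sketched with a greedy packer and a simple latency-plus-bandwidth (alpha-beta) time model. The sketch makes several assumptions: gradients are packed in the order they become ready, a 25 MB cap (DDP’s default `bucket_cap_mb`), a ~10 µs per-call latency, and a hypothetical 100 GB/s link. The ring factor is also ignored. DDP’s actual bucket assignment differs in details (for example, its first bucket is smaller), so treat this as an illustration of the trade-off, not DDP’s algorithm.

```python
def make_buckets(grad_sizes_mb, cap_mb=25.0):
    """Greedily pack gradients (in ready order) into buckets of at most cap_mb.

    A gradient larger than the cap gets a bucket of its own.
    """
    buckets, current, current_mb = [], [], 0.0
    for idx, size in enumerate(grad_sizes_mb):
        if current and current_mb + size > cap_mb:   # bucket full: launch it, start a new one
            buckets.append(current)
            current, current_mb = [], 0.0
        current.append(idx)
        current_mb += size
    if current:
        buckets.append(current)
    return buckets

def allreduce_time_us(total_mb, num_calls, latency_us=10.0, bw_gb_per_s=100.0):
    """Alpha-beta estimate: per-call latency plus transfer time (MB / (GB/s) = ms, so x1e3 for us)."""
    return num_calls * latency_us + total_mb / bw_gb_per_s * 1e3

grads = [4.0] * 50                 # hypothetical: 50 gradient tensors of 4 MB each
buckets = make_buckets(grads)
print(len(buckets), "buckets")
print(f"unfused: {allreduce_time_us(sum(grads), len(grads)):.0f} us, "
      f"bucketed: {allreduce_time_us(sum(grads), len(buckets)):.0f} us")
```

In this model, fusion only saves the per-call latency term. The gain is therefore largest when there are many small gradients, and it vanishes for a few large ones.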

Compute-Communication Overlap

Overlap is the most important technique for hiding communication latency. The core idea: execute computation on one CUDA stream and communication on another CUDA stream simultaneously, exploiting the independence between GPU compute units and network hardware.

Overlap in DDP: During backward pass, once certain layers’ gradients are computed, immediately launch AllReduce for those gradients while continuing to compute earlier layers’ gradients.

Overlap in FSDP:

  • Forward pass: Prefetch next layer’s AllGather while computing current layer
  • Backward pass: Prefetch the AllGather for the next layer in backward order while computing the current layer’s gradients, and asynchronously run the ReduceScatter for the layer whose gradients were just computed

Overlap in TP: More challenging because AllReduce is on the compute path (not the gradient path) in every layer. One approach is to decompose AllReduce into ReduceScatter + AllGather, executing computation that doesn’t depend on communication results during the ReduceScatter wait.

The effectiveness of overlap depends on the communication-to-computation time ratio. When communication time is less than computation time, communication can be fully hidden; when communication time exceeds computation time, overlap can only partially hide latency.
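A back-of-the-envelope timeline model captures this ratio effect. Take, for example, 6 layers with 10 ms of compute and 2.5 ms of communication each. The model assumes three things: each layer’s communication depends only on that layer’s compute, the next layer’s compute does not wait on it, and the two streams do not slow each other down. Function names are hypothetical.

```python
def serial_time(num_layers, compute_ms, comm_ms):
    """No overlap: every layer computes, then communicates."""
    return num_layers * (compute_ms + comm_ms)

def overlapped_time(num_layers, compute_ms, comm_ms):
    """Two-stream pipeline: layer i's communication overlaps layer i+1's compute.

    The slower stream sets the steady-state pace; the first compute and the last
    communication cannot be hidden.
    """
    if num_layers == 0:
        return 0.0
    return compute_ms + comm_ms + (num_layers - 1) * max(compute_ms, comm_ms)

s = serial_time(6, 10.0, 2.5)
o = overlapped_time(6, 10.0, 2.5)
print(f"serial {s} ms, overlapped {o} ms, saving {(s - o) / s:.1%}")
# serial 75.0 ms, overlapped 62.5 ms, saving 16.7%
```

When communication is the slower stream (for example, 2 ms compute and 10 ms communication per layer), the same formula shows that only the compute is hidden. Total time is then dominated by communication.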

[Interactive figure: Communication-computation overlap with adjustable layer count, comm/compute ratio, and strategy. Example shown: 6 layers, compute 10 ms/layer, communication 2.5 ms/layer. Serial execution takes 75 ms. Optimized execution, with separate compute and communication streams, takes 63 ms, a 16.7% reduction in step time. Communication is ideally fully hidden when comm_time < compute_time. This requires GPU multi-stream support, and PCIe bandwidth may become a bottleneck.]

The interactive demonstration above shows the effects of three communication optimization strategies. Try adjusting the layer count and communication/computation ratio to observe performance differences across configurations. Key observations:

  1. When the communication/computation ratio is 10%, simple overlap nearly completely hides communication
  2. When the communication/computation ratio reaches 50%, AllReduce fusion or Bucket strategies are needed for further optimization
  3. Bucket AllReduce achieves the most uniform overlap by distributing communication across the entire backward pass

CUDA Stream Parallelism Implementation Details

To achieve effective compute-communication overlap, CUDA streams and events must be properly managed:

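import torch
import torch.distributed as dist

# Illustrative: all-reduce a copy of each layer's output (in DDP the reduced tensors would be gradients).
# `model` and the input activation `hidden` are assumed to exist already.
compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
pending = []  # (buffer, async work handle) pairs; buffers stay alive until the collective finishes

compute_stream.wait_stream(torch.cuda.current_stream())  # inputs were produced on the default stream
for layer in model.layers:
    # Compute stream: execute current layer forward pass
    with torch.cuda.stream(compute_stream):
        hidden = layer(hidden)
        to_reduce = hidden.detach().clone()  # separate buffer: the next layer never reads a tensor being reduced

    # Comm stream waits for this layer's compute, not for later layers
    event = compute_stream.record_event()
    comm_stream.wait_event(event)

    # Comm stream: async AllReduce that overlaps with the next layer's compute
    with torch.cuda.stream(comm_stream):
        pending.append((to_reduce, dist.all_reduce(to_reduce, async_op=True)))

# Make the current stream wait on the compute stream and on every collective before using results
torch.cuda.current_stream().wait_stream(compute_stream)
for _, work in pending:
    work.wait()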

Compilers (like torch.compile) can automatically generate this stream management code without requiring users to manually insert it. This is precisely the value of compilers in distributed optimization: automating the low-level stream scheduling, event synchronization, and memory management.

Graph Partitioning Algorithms

When compilers need to assign computation graphs to multiple devices, they face a graph partitioning problem. This problem is especially critical in pipeline parallelism: how to partition the $L$ layers of a model into $N$ stages so that each stage’s computation is as balanced as possible.

Weighted Graph Partitioning

Formally, given a computation graph $G = (V, E)$:

  • Node $v \in V$ has weight $w(v)$ (computation cost and memory usage)
  • Edge $e \in E$ has weight $c(e)$ (communication volume)
  • Objective: partition $V$ into $N$ subsets $V_1, ..., V_N$ such that:
    1. Load balance: Minimize $\max_i \sum_{v \in V_i} w(v)$
    2. Communication minimization: Minimize $\sum_{(u,v) \in E: \text{part}(u) \neq \text{part}(v)} c(u,v)$

This is an NP-hard problem (even for $N=2$), so compilers use approximation algorithms.
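The two objectives can be evaluated directly for any assignment. For a toy graph, all $N^{|V|}$ assignments can be enumerated, which shows why exact search does not scale. This sketch combines the objectives lexicographically: load balance first, then cut cost. That ordering is an assumption, since real partitioners use weighted or constrained formulations. The graph and names are hypothetical.

```python
from itertools import product

def evaluate(assign, node_w, edges):
    """Return (max part load, cut cost) for assign: node -> part id."""
    loads = {}
    for v, w in node_w.items():
        loads[assign[v]] = loads.get(assign[v], 0) + w
    cut = sum(c for (u, v), c in edges.items() if assign[u] != assign[v])
    return max(loads.values()), cut

def brute_force(node_w, edges, num_parts):
    """Try all num_parts ** |V| assignments: exponential, only viable for toy graphs."""
    nodes = list(node_w)
    best = None
    for parts in product(range(num_parts), repeat=len(nodes)):
        assign = dict(zip(nodes, parts))
        score = evaluate(assign, node_w, edges)     # tuples compare lexicographically: load, then cut
        if best is None or score < best[0]:
            best = (score, assign)
    return best

# Diamond-shaped graph a -> {b, c} -> d; the b -> d edge carries a large tensor
node_w = {"a": 2, "b": 3, "c": 3, "d": 2}
edges = {("a", "b"): 1, ("a", "c"): 1, ("b", "d"): 5, ("c", "d"): 1}
print(brute_force(node_w, edges, 2))
```

Both $\{a, b\} \mid \{c, d\}$ and $\{a, c\} \mid \{b, d\}$ are perfectly balanced (load 5 each). Only the second keeps the expensive $b \to d$ edge inside one partition, giving a cut cost of 2 instead of 6.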

PP Stage Partitioning Algorithms

For pipeline parallelism, since Transformer model layers are typically a linear sequence, the problem simplifies to sequence partitioning:

Greedy algorithm: Divide total computation by $N$, greedily assigning layers to stages so each stage’s computation approaches $\text{total}/N$. Time complexity $O(L)$, but not guaranteed optimal.

Dynamic Programming (DP): $\text{dp}[i][j]$ represents the minimum possible maximum stage computation when partitioning the first $i$ layers into $j$ stages, with base case $\text{dp}[0][0] = 0$:

$$\text{dp}[i][j] = \min_{k < i} \max\left(\text{dp}[k][j-1], \sum_{l=k+1}^{i} w(l)\right)$$

Time complexity $O(L^2 N)$, entirely feasible for scenarios like $L=96, N=8$.
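Both approaches fit in a short sketch. The greedy version cuts a stage as soon as its running sum reaches $\text{total}/N$. The DP version implements the recurrence with prefix sums and recovers the stage boundaries from a table of best split points. Layer weights are hypothetical per-layer costs, and layers are 0-indexed in code. The DP assumes there are at least as many layers as stages.

```python
def greedy_partition(weights, num_stages):
    """Cut a new stage once the running sum reaches total / num_stages. O(L), not optimal."""
    target = sum(weights) / num_stages
    stages, current, acc = [], [], 0.0
    for i, w in enumerate(weights):
        current.append(i)
        acc += w
        remaining_layers = len(weights) - i - 1
        remaining_stages = num_stages - len(stages) - 1
        if acc >= target and remaining_stages > 0 and remaining_layers >= remaining_stages:
            stages.append(current)
            current, acc = [], 0.0
    stages.append(current)
    return stages

def dp_partition(weights, num_stages):
    """Optimal contiguous split minimizing the max stage load. O(L^2 N)."""
    L = len(weights)
    prefix = [0.0]
    for w in weights:
        prefix.append(prefix[-1] + w)
    INF = float("inf")
    dp = [[INF] * (num_stages + 1) for _ in range(L + 1)]   # dp[i][j]: first i layers in j stages
    cut = [[0] * (num_stages + 1) for _ in range(L + 1)]    # best k for each (i, j)
    dp[0][0] = 0.0
    for i in range(1, L + 1):
        for j in range(1, min(i, num_stages) + 1):
            for k in range(j - 1, i):          # layers k..i-1 (0-indexed) form stage j
                cost = max(dp[k][j - 1], prefix[i] - prefix[k])
                if cost < dp[i][j]:
                    dp[i][j], cut[i][j] = cost, k
    stages, i = [], L                          # walk the cut table backwards to recover stages
    for j in range(num_stages, 0, -1):
        k = cut[i][j]
        stages.append(list(range(k, i)))
        i = k
    return dp[L][num_stages], stages[::-1]

def stage_loads(stages, weights):
    return [sum(weights[i] for i in stage) for stage in stages]

w = [1, 1, 1, 5, 1, 1]
g = greedy_partition(w, 2)
print("greedy:", g, "max load", max(stage_loads(g, w)))
print("dp:    ", dp_partition(w, 2))
```

On these weights, the greedy rule swallows the heavy layer into the first stage (max load 8). The DP instead finds the split before it (max load 7).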

Integer Linear Programming (ILP): For more complex constraints (memory limits, heterogeneous devices), the problem can be modeled as an ILP. Although ILP is exponential in the worst case, Transformer model structures are regular enough that modern ILP solvers (Gurobi, CPLEX) can often solve these instances quickly.

Alpa (Automated inter- and intra-operator parallelism) employs a two-level search strategy:

Intra-operator parallelism: For each stage, use ILP to find the optimal TP sharding scheme (similar to GSPMD’s sharding propagation, but searched within an optimizer)

Inter-operator parallelism: Use DP to search for the optimal PP stage partition, where each stage’s cost is provided by the intra-op ILP

The key advantage of this hierarchical search: it decomposes an exponential search space into two tractable subproblems. Alpa demonstrated performance close to expert manual tuning in experiments while being fully automated.

Practical Implementation in Compilers

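XLA’s SPMD partitioner is itself annotation-driven. When it is paired with an automatic sharding search in the spirit of Alpa’s ILP formulation, the graph partitioning flow roughly proceeds as follows: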

  1. Profiling: First run lightweight profiling of the computation graph on a single device to obtain each operator’s compute time and memory usage
  2. Cost estimation: Based on profiling data and a communication model (considering device topology), estimate the total cost of each partitioning scheme
  3. Search: Find the lowest-cost partitioning scheme in the search space (using DP or ILP)
  4. Lowering: Convert the partitioning scheme into concrete sharding specs, using propagation to complete all tensors’ sharding
  5. Code generation: Generate SPMD code for each device (all devices execute the same code but operate on different data shards), inserting necessary communication operators

This flow demonstrates the compiler’s core value in distributed optimization: transforming users’ high-level intent (“train this model on 64 GPUs”) into efficient low-level execution plans, automatically handling the complexity of partitioning, communication, and scheduling.

Summary

Distributed compilation is one of the most complex challenges facing ML compilers. This article covered the following core topics:

Parallelism strategy fundamentals: DP (simple but memory-limited), TP (efficient but requires NVLink), PP (memory-saving but has bubbles), FSDP (memory-optimized DP), and hybrid parallelism strategies.

GSPMD automatic partitioning: Sharding specification unifies various parallelism strategies, the sharding propagation algorithm automatically infers complete partitioning schemes, and the cost model selects the optimal among multiple candidates.

torch.compile + distributed: The DTensor abstraction lets compilers perceive and optimize distributed communication; FSDP2/TP integration with the compiler already delivers 10-20% throughput improvements.

Communication optimization: Hardware topology awareness, AllReduce fusion, compute-communication overlap — compilers achieve more systematic communication scheduling than manual optimization through their global perspective.

Graph partitioning algorithms: From greedy to DP to ILP, and Alpa’s hybrid search, compilers are progressively automating the decisions in distributed training that previously required the most expertise.

The future direction of distributed compilation is “full automation”: users specify only the model and hardware configuration, and the compiler automatically searches for the optimal parallelism strategy, partitioning scheme, and communication schedule. GSPMD and Alpa have taken important steps in this direction, while torch.compile’s distributed integration is bringing these techniques to the broader PyTorch ecosystem.

In the next article, we will discuss Scheduling and Execution Optimization — after the compiler has completed graph optimization, fusion, and partitioning, how to efficiently schedule these operations for GPU execution, including CUDA Stream orchestration, CUDA Graph capture, and memory planning.