11.2b Distributed Training Fundamentals: DP / TP / PP / SP / ZeRO

🖥️ "Training a 70B parameter model — a single A100 (80GB VRAM) simply can't fit it. Even if it could, you'd retire before training finishes. Distributed training isn't a luxury; it's a prerequisite for large models to exist."


Why Is Distributed Training Necessary?

Before understanding the various parallelism strategies, let's clarify the two core constraints: the memory wall and the compute wall.

Memory Wall: One GPU Can't Fit It

Take training a 7B-parameter model as an example. In mixed-precision (BF16) training, GPU memory usage comes from the following parts:

7B Model Mixed Precision Training Memory Estimation

As shown in the diagram, optimizer states (the FP32 master weights plus Adam's m and v moments) are the largest overhead, exceeding the total memory of an A100 on their own. Adding parameters, gradients, and activations, mixed-precision training of a 7B model requires 106–138 GB, far exceeding the 80 GB limit of a single A100.
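The arithmetic behind that estimate fits in a few lines. A rough sketch assuming BF16 parameters and gradients plus FP32 master weights and Adam moments (the function name is illustrative, not from any library):

```python
def training_memory_gb(n_params):
    """Fixed training memory in decimal GB, excluding activations."""
    # 2 (BF16 params) + 2 (BF16 grads) + 12 (FP32 master copy + Adam m + v)
    bytes_per_param = 2 + 2 + 12
    return bytes_per_param * n_params / 1e9

# 7B model: fixed cost before activations; activations add tens of GB
# depending on batch size, sequence length, and checkpointing
print(training_memory_gb(7e9))  # 112.0
```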

Compute Wall: One GPU Can't Wait

Using Llama-3 70B as an example, training on 1T tokens requires approximately 6 × N × D = 6 × (70 × 10⁹) × 10¹² ≈ 4.2 × 10²³ floating point operations.

A single A100 (312 TFLOPS) would need 4.2 × 10²³ / (3.12 × 10¹⁴ FLOP/s) ≈ 1.35 × 10⁹ seconds ≈ 43 years.

A cluster of 1,000 A100s still needs about 15 days; this is the actual time scale of real training.
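These numbers can be reproduced with the common ~6·N·D FLOPs rule of thumb. A back-of-envelope sketch assuming perfect GPU utilization (real MFU is lower, so real runs take longer):

```python
N = 70e9            # parameters
D = 1e12            # training tokens
flops = 6 * N * D   # ~4.2e23 total training FLOPs

a100_peak = 312e12  # A100 BF16 peak, FLOP/s
single_gpu_days = flops / a100_peak / 86400

print(f"total: {flops:.2e} FLOPs")
print(f"1 GPU:     {single_gpu_days / 365:.0f} years")
print(f"1000 GPUs: {single_gpu_days / 1000:.1f} days")
```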

Distributed training solves both problems simultaneously by distributing computation and storage across multiple GPUs.


Overview of Five Parallelism Dimensions

Modern LLM training uses up to 5 parallelism dimensions simultaneously, each targeting a different axis of the computation graph:

LLM Distributed Training Five Parallelism Strategies

| Parallelism Strategy | Abbreviation | Split Dimension | Core Problem Solved |
|---|---|---|---|
| Data Parallelism | DP | Batch dimension | Training speed (throughput) |
| Tensor Parallelism | TP | Weight matrix dimension | Single layer's parameters too large |
| Pipeline Parallelism | PP | Model layer (depth) dimension | Too many layers |
| Sequence Parallelism | SP | Sequence length dimension | Long-sequence activations too large |
| Expert Parallelism | EP | MoE expert dimension | Too many experts in MoE models |

ZeRO is an optimizer state sharding technique used with DP; strictly speaking it's not a new parallelism dimension, but it's critical for memory optimization.


I. Data Parallelism (DP / DDP)

Core Idea

Data parallelism is the simplest and most commonly used parallelism strategy: each GPU holds a complete model replica but processes different data subsets, then aggregates gradients to update parameters.

Data Parallelism DP Principle + Ring-AllReduce Communication

As shown on the left side of the diagram, each GPU has a complete model replica, performs forward + backward computation on different batches, then aggregates averaged gradients via AllReduce to synchronously update all replicas' parameters. The right side shows Ring-AllReduce communication — GPUs are arranged in a ring, gradient fragments are passed along the ring, and each GPU's communication volume is independent of the number of GPUs, achieving near-linear scaling.

DDP vs DP

PyTorch provides two data parallelism implementations:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel as DDP

# ─── Method 1: DataParallel (DP) — single machine multi-GPU, simple but has bottlenecks ───
# Problem: master GPU (GPU 0) aggregates outputs and gradients, becomes a communication bottleneck
# Problem: uneven memory usage across GPUs (master GPU bears more load)
model = MyModel().cuda()   # MyModel: your nn.Module
model = DataParallel(model, device_ids=[0, 1, 2, 3])
output = model(input)      # Automatically splits the batch across GPUs

# ─── Method 2: DistributedDataParallel (DDP) — Recommended! ───
# Advantage: each GPU independently computes gradients, Ring-AllReduce communicates evenly
# Advantage: supports multi-machine multi-GPU, better linear scalability

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train_ddp(rank, world_size, model, dataset):
    setup(rank, world_size)

    # Each process holds one model copy
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])

    # DistributedSampler ensures each process sees a different data shard
    sampler = torch.utils.data.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch)      # Assumes the model returns the loss
        loss.backward()          # DDP automatically triggers AllReduce during backward
        optimizer.step()

    dist.destroy_process_group()

AllReduce: Communication Algorithm for Gradient Aggregation

The core operation of DDP is Ring-AllReduce (shown on the right side of the diagram above): each GPU passes its gradient fragment along the ring step by step; after 2(N-1) rounds of communication, each GPU has the complete averaged gradient. The key property is that communication volume is independent of the number of GPUs — even with 1,000 GPUs, each GPU's communication volume is still approximately 2× the parameter count, achieving ideal linear scaling.
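The two phases (reduce-scatter, then all-gather) can be simulated on plain Python lists, one list per "GPU". This is a toy sketch; real implementations move tensor buffers via NCCL:

```python
# Toy Ring-AllReduce: each gradient vector is split into n chunks.
# After n-1 reduce-scatter steps plus n-1 all-gather steps, every node
# holds the full sum, and per-node traffic is independent of n.

def ring_allreduce(chunks_per_node):
    n = len(chunks_per_node)
    data = [list(c) for c in chunks_per_node]  # data[node][chunk]
    # Phase 1 -- reduce-scatter: chunk c accumulates around the ring and
    # ends up fully summed at node (c - 1) % n
    for s in range(n - 1):
        outgoing = [(i, (i - s) % n, data[i][(i - s) % n]) for i in range(n)]
        for i, c, val in outgoing:
            data[(i + 1) % n][c] += val
    # Phase 2 -- all-gather: circulate the fully reduced chunks around the ring
    for s in range(n - 1):
        outgoing = [(i, (i + 1 - s) % n, data[i][(i + 1 - s) % n]) for i in range(n)]
        for i, c, val in outgoing:
            data[(i + 1) % n][c] = val
    return data

grads = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]  # 3 GPUs, 3 chunks
print(ring_allreduce(grads))  # every GPU ends with [12.0, 15.0, 18.0]
```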

Limitations of DP

Memory problem: each GPU must store the complete model parameters + gradients + optimizer states. For a 70B model that is roughly 1,120 GB per GPU, and adding more DP replicas does not reduce it by a single byte. This led to the development of ZeRO (see below).

Effective batch size: global_batch_size = local_batch_size × world_size. With DP=1000, the effective batch may be too large, causing training instability; needs to be combined with Gradient Accumulation:

# Gradient Accumulation: simulate large batch without increasing actual batch size
accumulation_steps = 8  # Accumulate 8 mini-batches before updating

for step, batch in enumerate(loader):
    loss = model(batch) / accumulation_steps  # Scale loss
    loss.backward()                            # Accumulate gradients
    
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                       # Update every 8 steps
        optimizer.zero_grad()

II. Tensor Parallelism (TP)

Core Idea

Tensor parallelism splits a single weight matrix across GPUs; each GPU holds only a portion of the matrix and computes matrix multiplication in parallel.

This is the core technique proposed by Megatron-LM [1], specifically targeting the two dense layers of Transformers:

Tensor Parallelism TP Column/Row Parallel Matrix Splitting and MLP Data Flow

As shown in the diagram, TP has two splitting methods:

  • Column Parallel: split weight W by columns; each GPU independently computes local output, then Concat at the end — no communication needed in forward pass.
  • Row Parallel: split weight by rows while also splitting input; each GPU independently computes then AllReduce sums — 1 AllReduce needed in forward pass.

The actual Transformer MLP layer consists of "column parallel + activation function + row parallel" (see data flow at bottom of diagram), requiring 2 AllReduces per layer (1 forward + 1 backward).

Column Parallel Linear concept:

  • Weight W [H, 4H] split by columns into W₀ [H, 2H] and W₁ [H, 2H]
  • GPU 0 computes Y₀ = X × W₀, GPU 1 computes Y₁ = X × W₁
  • Output Concat: Y = [Y₀ | Y₁] [B, s, 4H]

Row Parallel Linear concept:

  • Weight W [4H, H] split by rows, input also split accordingly
  • GPU 0: X₀ × W₀ = Y₀, GPU 1: X₁ × W₁ = Y₁
  • AllReduce: Y = Y₀ + Y₁ [B, s, H]
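Both splits can be checked numerically with plain nested lists standing in for tensors (TP=2; the `matmul` helper exists only for this demo):

```python
# Column parallel: [X@W0 | X@W1] == X@W  (concat local outputs)
# Row parallel:    X0@W0 + X1@W1 == X@W  (sum partial products = AllReduce)

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

X = [[1.0, 2.0, 3.0, 4.0]]                                     # [1, 4]
W = [[1, 0, 2, 1], [0, 1, 1, 2], [3, 1, 0, 0], [1, 2, 1, 1]]   # [4, 4]

# Column parallel: split W's columns between two "GPUs", concat the outputs
W0 = [row[:2] for row in W]
W1 = [row[2:] for row in W]
Y_col = [ya + yb for ya, yb in zip(matmul(X, W0), matmul(X, W1))]

# Row parallel: split W's rows and X's columns, sum the partial products
Wa, Wb = W[:2], W[2:]
Xa = [row[:2] for row in X]
Xb = [row[2:] for row in X]
Y_row = [[p + q for p, q in zip(ra, rb)]
         for ra, rb in zip(matmul(Xa, Wa), matmul(Xb, Wb))]

assert Y_col == matmul(X, W) == Y_row
print(Y_col)  # [[14.0, 13.0, 8.0, 9.0]]
```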

TP Decomposition of MLP Layer

# TP decomposition of standard FFN layer (Megatron-LM style)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


class ColumnParallelLinear(nn.Module):
    """
    Weight split by columns:
    W_full [H, 4H] → each GPU holds W_local [H, 4H/tp_size]
    """
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        # Each GPU only holds 1/tp_size of the columns
        self.weight = nn.Parameter(
            torch.randn(in_features, out_features // tp_size)
        )

    def forward(self, x):
        # Local matrix multiplication, no communication needed
        # (weight is stored as [in, out/tp], so multiply directly)
        return x @ self.weight  # [B, s, out/tp]


class RowParallelLinear(nn.Module):
    """
    Weight split by rows:
    W_full [4H, H] → each GPU holds W_local [4H/tp_size, H]
    Input x is already local (the output of a ColumnParallelLinear)
    """
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        self.weight = nn.Parameter(
            torch.randn(in_features // tp_size, out_features)
        )

    def forward(self, x):
        # Local matrix multiplication produces a partial sum
        local_output = x @ self.weight  # [B, s, H], partial
        # AllReduce sums the partial results from all GPUs
        dist.all_reduce(local_output, op=dist.ReduceOp.SUM)
        return local_output


class TensorParallelMLP(nn.Module):
    """
    Complete MLP TP implementation:
    FFN(x) = GeLU(x @ W_up) @ W_down

    Communication pattern (pure TP, input replicated on every GPU):
    x → Column parallel W_up (no comm) → GeLU → Row parallel W_down → [AllReduce]
    Total: 1 AllReduce in forward + 1 AllReduce in backward
    """
    def __init__(self, hidden_size, ffn_size, tp_size):
        super().__init__()
        self.up_proj = ColumnParallelLinear(hidden_size, ffn_size, tp_size)
        self.down_proj = RowParallelLinear(ffn_size, hidden_size, tp_size)

    def forward(self, x):
        return self.down_proj(F.gelu(self.up_proj(x)))

TP for Attention Layers

Multi-head attention (MHA) TP is more natural — split directly by attention heads:

# Attention TP: each GPU handles a subset of attention heads
class TensorParallelAttention(nn.Module):
    """
    Assuming n_heads=32, TP=4:
    Each GPU handles 8 heads

    Q, K, V projections: column parallel (output split by heads)
    Output projection: row parallel (input split + AllReduce)
    """
    def __init__(self, d_model, n_heads, tp_size):
        super().__init__()
        self.local_heads = n_heads // tp_size       # Number of heads per GPU
        self.head_dim = d_model // n_heads
        local_d = self.local_heads * self.head_dim  # Q/K/V width on this GPU

        # Column parallel: out_features = 3*d_model is split tp_size ways,
        # so each GPU holds the Q/K/V projections of its own heads (3*local_d)
        self.qkv_proj = ColumnParallelLinear(d_model, 3 * d_model, tp_size)
        # Row parallel: in_features = d_model is split tp_size ways (local_d
        # per GPU); AllReduce aggregates the partial attention outputs
        self.out_proj = RowParallelLinear(d_model, d_model, tp_size)

Applicable Scenarios and Limitations of TP

| Feature | Description |
|---|---|
| Communication volume | 2 AllReduces per layer (1 forward + 1 backward) |
| Communication latency | High: each layer must wait for communication to finish (synchronous) |
| Recommended GPU connection | Within one node (NVLink); cross-node bandwidth is too low |
| Suitable for | Single layer's parameters too large (e.g. the 4H×H MLP weights) |
| Not suitable for | Many small layers; cross-node scaling |
| Typical scale | TP=4~8 (within a single 8-GPU server) |

III. Pipeline Parallelism (PP)

Core Idea

Pipeline parallelism splits the model by layers, with different GPUs responsible for different layer groups, processing in parallel like a factory assembly line:

Pipeline Parallelism PP: Sequential Execution vs 1F1B Schedule Gantt Chart

As shown in the diagram, without pipelined scheduling most GPUs sit idle waiting; the 1F1B schedule lets each GPU alternate between forward and backward in the steady state, significantly improving utilization. The bubble rate formula is:

bubble rate = (P − 1) / (M + P − 1)

where P is the number of pipeline stages and M is the number of micro-batches. When M ≫ P, the bubble rate approaches zero.

Bubble (Pipeline Bubble) Problem

The biggest challenge of pipeline parallelism is bubbles — idle time when GPUs wait for the previous stage's output:

# GPipe scheduling strategy (naive PP)
class GPipe:
    """
    Each micro-batch completes full forward propagation before backward propagation
    Bubble rate = (PP_stages - 1) / (micro_batches + PP_stages - 1)
    
    When micro_batches >> PP_stages, bubble rate → 0
    But too many micro_batches increases memory (need to cache all activations)
    """
    def __init__(self, stages=4, micro_batches=8):
        self.stages = stages
        self.micro_batches = micro_batches
        self.bubble_rate = (stages - 1) / (micro_batches + stages - 1)
        print(f"Bubble rate: {self.bubble_rate:.1%}")  # stages=4, m=8 → 27.3%

1F1B scheduling (introduced in PipeDream and adopted by Megatron-LM; superior):

The core improvement of 1F1B is lower peak memory — it only needs to cache stages micro-batches' activations (while GPipe needs to cache all M), so it saves more memory at the same bubble rate. See the scheduling comparison in the Gantt chart above.
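The memory difference can be captured in a toy model of how many micro-batches' activations each stage must hold at its peak. A simplified sketch (`peak_inflight` and its rules are illustrative, not a framework API):

```python
# GPipe runs all M forwards before any backward, so every stage peaks at M
# micro-batches' activations; 1F1B caps stage i at roughly P - i in-flight
# micro-batches (its warmup depth), independent of M.

def peak_inflight(schedule, num_stages, num_micro_batches):
    if schedule == "gpipe":
        return [num_micro_batches] * num_stages
    if schedule == "1f1b":
        return [min(num_micro_batches, num_stages - i) for i in range(num_stages)]
    raise ValueError(schedule)

print(peak_inflight("gpipe", 4, 32))  # [32, 32, 32, 32]
print(peak_inflight("1f1b", 4, 32))   # [4, 3, 2, 1]
```

Both schedules have the same bubble rate at equal M; 1F1B simply bounds peak activation memory by the stage count instead of the micro-batch count.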

class OneFOneB:
    """
    1F1B scheduling core idea:
    In steady state, each time step is 1 forward + 1 backward
    Avoids GPipe's need to cache all micro-batch activations
    """
    def schedule(self, num_stages, num_micro_batches):
        """
        Returns the execution sequence for each stage
        F = Forward, B = Backward
        Numbers represent micro-batch IDs
        """
        schedule = {stage: [] for stage in range(num_stages)}
        
        # Warmup phase: first PP_stages micro-batches do forward only
        for stage in range(num_stages):
            for mb in range(num_stages - stage):
                schedule[stage].append(f"F{mb}")
        
        # Steady state: each GPU alternates 1F1B
        # ...(see Megatron-LM source code for actual implementation)
        
        return schedule

Applicable Scenarios and Limitations of PP

| Feature | Description |
|---|---|
| Communication volume | Only activations between adjacent stages (point-to-point, low volume) |
| Communication latency | Relatively low (P2P communication) |
| Recommended GPU connection | Works across nodes (100 Gbps IB is sufficient) |
| Suitable for | Many model layers, each layer not too large |
| Disadvantages | Pipeline bubbles; complex debugging; high implementation difficulty |
| Typical scale | PP=4~16 (multi-machine multi-GPU) |

IV. Sequence Parallelism (SP)

Core Idea

Sequence Parallelism splits along the sequence length dimension, primarily solving the problem of activation memory explosion with long sequences.

In standard Transformers, activation memory grows quadratically with sequence length (the attention matrix is O(S²)); when sequence length grows from 2K to 128K, this becomes a fatal bottleneck.

Sequence Parallelism SP+TP Communication Data Flow + Ring Attention (CP)

As shown in the diagram, SP+TP splits the standard TP AllReduce into AllGather + ReduceScatter, keeping the sequence sharded in the SP regions (LayerNorm, Dropout, etc.) and thereby cutting activation memory at those points by a factor of tp_size. The Ring Attention (Context Parallelism) shown at the bottom of the diagram can further reduce per-GPU attention-matrix memory from O(S²) to O(S²/cp²), supporting ultra-long context training exceeding 1M tokens.

SP + TP Combination (Megatron-LM v3)

SP is usually used in combination with TP (within the same group of GPUs), jointly reducing activations (complete data flow already shown in the diagram above).

Key optimization: replacing TP's AllReduce with the ReduceScatter + AllGather pair keeps activations sharded along the sequence dimension, reducing activation memory in the sequence-parallel regions by a factor of tp_size.
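This rests on the identity AllReduce(x) = AllGather(ReduceScatter(x)); between the two halves, each GPU holds only its sequence shard. A list-based sketch (toy functions, not a real collective API):

```python
# One list per "GPU"; vector length must be divisible by the GPU count.

def all_reduce(xs):
    total = [sum(col) for col in zip(*xs)]
    return [list(total) for _ in xs]

def reduce_scatter(xs):
    n = len(xs)
    chunk = len(xs[0]) // n
    total = [sum(col) for col in zip(*xs)]
    # GPU i keeps only its shard of the summed vector
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]

def all_gather(shards):
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]

xs = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 GPUs, 4 elements each
assert all_gather(reduce_scatter(xs)) == all_reduce(xs)
```

The memory win comes from everything that runs between the ReduceScatter and the next AllGather (Dropout, LayerNorm) operating on a 1/n-sized shard.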

# SP+TP communication pattern (comparison)
class SequenceParallelism:
    """
    Standard TP communication (input replicated on every GPU):
      Forward:  ColParallel (no comm) → RowParallel → AllReduce
      Backward: AllReduce at the column-parallel input

    SP+TP communication (activations stay sharded along the sequence
    dimension outside the TP region):
      Forward:  AllGather(x_shard) → ColParallel → RowParallel → ReduceScatter
      Backward: AllGather(∇_shard) → RowParallel → ColParallel → ReduceScatter

    Memory savings: activations in sequence-parallel regions (Dropout,
    LayerNorm) reduced by a factor of tp_size
    Communication volume: same as standard TP
    (AllGather + ReduceScatter ≈ one AllReduce)
    """
    pass

Context Parallelism (CP): Ring Attention

When sequence length exceeds what SP can handle (e.g., 1M tokens), there's a more aggressive approach — Context Parallelism:

# Ring Attention (core of CP)
# Split Q/K/V along the sequence dimension,
# implement distributed attention computation via Ring communication

class RingAttention:
    """
    Core idea:
    - Each GPU keeps its own Q segment [B, S/cp, H]
    - K/V blocks circulate between GPUs via Ring communication
    - Each GPU computes attention against each K/V block locally
    - The blockwise results merge into the complete attention output

    Communication volume: O(S/cp × H × cp) = O(S × H), independent of layer count
    Memory: attention matrix reduced from O(S²) to O(S²/cp²)

    Simplified sketch: a real implementation must merge blockwise softmax
    statistics (running max and denominator, as in FlashAttention) and apply
    the correct causal mask per block; both are elided here.
    """
    def forward(self, q, k, v, cp_group):
        S_local = q.shape[1]  # S / cp
        output = torch.zeros_like(q)

        # Local K/V block
        k_local = k.clone()
        v_local = v.clone()

        for step in range(self.cp_size):  # cp_size = GPUs in cp_group
            # Attention of local Q against the current K/V block
            # (flash_attn stands in for a blockwise attention kernel)
            attn_out = flash_attn(q, k_local, v_local, causal=(step == 0))
            output += attn_out  # simplified: real code rescales by softmax stats

            # Pass K/V blocks to the next GPU along the ring
            k_local = self.ring_send_recv(k_local, cp_group)
            v_local = self.ring_send_recv(v_local, cp_group)

        return output

V. ZeRO: Eliminating Optimizer Redundancy

Problem Source

In standard DDP, each GPU stores a complete copy of:

  • Model parameters: BF16, 2 bytes/param
  • Gradients: BF16, 2 bytes/param
  • Optimizer states: FP32 master weights + Adam's m + v, 12 bytes/param

Total approximately 16 bytes/param. For a 70B model that is 1,120 GB per GPU, no matter how many GPUs you add. Worse, every GPU's optimizer states are completely identical. This is enormous redundancy!

ZeRO Three Stages

ZeRO (Zero Redundancy Optimizer) was proposed by Microsoft DeepSpeed [2], eliminating redundancy through sharding:

ZeRO 0/1/2/3 Three-Stage Sharding Comparison (4 GPUs, 16 bytes/param example)

As shown in the diagram, ZeRO progressively shards in three stages, gradually eliminating redundancy in optimizer states, gradients, and model parameters:

| Stage | Sharded Content | Memory per GPU (N=4) | Savings |
|---|---|---|---|
| ZeRO-0 (DDP) | No sharding | 80 B/param | baseline |
| ZeRO-1 | Optimizer states | ~38 B/param | 52% |
| ZeRO-2 | + Gradients | ~21 B/param | 74% |
| ZeRO-3 | + Model parameters | 80/N B/param | scales as 1/N |

Training a 70B model with 1,000 A100s: ~1.12GB per GPU — easily fits!
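The per-GPU numbers follow the ZeRO paper's accounting (2 bytes of params + 2 bytes of grads + K = 12 bytes of optimizer state per parameter). A small sketch (the function name is illustrative):

```python
def zero_memory_per_gpu(n_params, n_gpus, stage, K=12):
    """Bytes per GPU under each ZeRO stage (2 + 2 + K bytes/param accounting)."""
    p, g, opt = 2 * n_params, 2 * n_params, K * n_params
    if stage == 0:
        return p + g + opt             # plain DDP: everything replicated
    if stage == 1:
        return p + g + opt / n_gpus    # shard optimizer states
    if stage == 2:
        return p + (g + opt) / n_gpus  # + shard gradients
    if stage == 3:
        return (p + g + opt) / n_gpus  # + shard parameters
    raise ValueError(stage)

# 70B model, ZeRO-3 across 1,000 GPUs
print(zero_memory_per_gpu(70e9, 1000, stage=3) / 1e9)  # 1.12 (GB)
```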

# Training with DeepSpeed ZeRO
import deepspeed

# ZeRO Stage configuration
ds_config = {
    "zero_optimization": {
        "stage": 3,                      # ZeRO-3: full sharding
        "offload_optimizer": {
            "device": "cpu",             # Optional: offload optimizer states to CPU
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",             # Optional: offload parameters to CPU (ZeRO-Infinity)
        },
        "overlap_comm": True,            # Overlap communication with computation
        "contiguous_gradients": True,    # Contiguous memory improves communication efficiency
        "sub_group_size": 1e9,
        "reduce_bucket_size": 5e8,
    },
    "bf16": {"enabled": True},
    "gradient_checkpointing": True,
}

# Initialize DeepSpeed engine
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
)

# Training loop is almost identical to standard PyTorch
for batch in dataloader:
    loss = model_engine(batch)
    model_engine.backward(loss)    # Replaces loss.backward()
    model_engine.step()            # Replaces optimizer.step()

ZeRO++ and ZeRO-Infinity

ZeRO++ (2023) further compresses communication volume on top of ZeRO-3:

  • qwZ (quantized weights): quantize FP16 to INT8 during AllGather, reducing communication volume by 50%
  • hpZ (hierarchical partitioning): prioritize intra-node partitioning to reduce cross-node traffic
  • qgZ (quantized gradients): quantize before ReduceScatter to further reduce bandwidth requirements

ZeRO-Infinity offloads parameters/gradients/optimizer states to CPU RAM or NVMe SSD, theoretically enabling training of arbitrarily large models, but speed is limited by PCIe bandwidth; suitable for "large model + few GPUs" research scenarios.

FSDP (PyTorch Native ZeRO)

PyTorch 2.0+ has built-in Fully Sharded Data Parallel (FSDP), which is the native implementation of ZeRO-3:

from torch.nn import TransformerDecoderLayer  # stand-in for your model's block class
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
)

# FSDP configuration
fsdp_config = dict(
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # ZeRO-3 equivalent
    # ShardingStrategy.SHARD_GRAD_OP = ZeRO-2
    # ShardingStrategy.NO_SHARD = standard DDP
    
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # Prefetch next layer parameters
    cpu_offload=None,  # or CPUOffload(offload_params=True)
    auto_wrap_policy=lambda module, recurse, *args: (
        recurse or isinstance(module, TransformerDecoderLayer)
    ),
)

model = FSDP(model, **fsdp_config)

VI. Expert Parallelism (EP)

EP is specifically for MoE (Mixture of Experts) models, such as Mixtral, DeepSeek-V3, etc.

MoE Review

Expert Parallelism EP: MoE Routing and AllToAll Dynamic Distribution

As shown on the left side of the diagram, standard FFN uses the same FFN for every token; MoE FFN sets up multiple experts, and each token only passes through K of them (Top-K routing), greatly increasing parameter count without increasing activation computation. DeepSeek-V3 uses 256 experts, activating 8 per token, with 671B total parameters but ~37B activated.
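The Top-K selection step can be sketched in a few lines of pure Python: softmax over router logits, keep the K largest, renormalize their weights. Real routers add load-balancing losses and capacity limits, which this toy omits:

```python
import math

def top_k_route(logits, k):
    """Return the chosen expert ids and their renormalized gate weights."""
    probs = [math.exp(l) for l in logits]
    z = sum(probs)
    probs = [p / z for p in probs]                       # softmax
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    w = sum(probs[i] for i in top)
    return top, [probs[i] / w for i in top]              # renormalize over Top-K

experts, weights = top_k_route([0.1, 2.0, -1.0, 1.5], k=2)
print(experts)  # [1, 3]
```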

EP Splitting Method

As shown on the right side of the diagram, EP assigns different experts to different GPUs (with 256 experts and EP=8 as an example, each GPU holds 32 experts). The biggest challenge is that MoE routing is dynamic — which expert each token goes to is decided at runtime, requiring two AllToAll communications (distribute tokens → compute → collect results).

class ExpertParallelMoE(nn.Module):
    """
    Expert-parallel MoE layer (sketch: MLP, Router, and the all_to_all
    dispatch/combine helpers used below are placeholders, not runnable code)

    Communication pattern:
    1. Router decides which expert each token goes to (local computation)
    2. AllToAll: send tokens to the GPUs holding their experts
    3. Each GPU independently computes FFN for its experts
    4. AllToAll: send results back to the original GPUs
    5. Weighted merge of expert outputs
    """
    def __init__(self, d_model, n_experts, n_experts_per_token, ep_group):
        super().__init__()
        self.ep_group = ep_group
        self.ep_size = dist.get_world_size(ep_group)
        self.local_n_experts = n_experts // self.ep_size
        
        # Each GPU only holds local_n_experts experts
        self.experts = nn.ModuleList([
            MLP(d_model) for _ in range(self.local_n_experts)
        ])
        self.router = Router(d_model, n_experts, n_experts_per_token)
    
    def forward(self, x):
        B, S, D = x.shape
        x_flat = x.view(-1, D)  # [B*S, D]
        
        # 1. Routing computation (local)
        expert_indices, expert_weights = self.router(x_flat)
        
        # 2. AllToAll: distribute tokens to corresponding GPUs
        x_dispatched = self.all_to_all_dispatch(x_flat, expert_indices)
        
        # 3. Local expert computation
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            # Get tokens assigned to this GPU's i-th expert
            tokens_for_expert = x_dispatched[i]
            if tokens_for_expert.shape[0] > 0:
                expert_outputs.append(expert(tokens_for_expert))
        
        # 4. AllToAll: send results back to original GPUs
        combined = self.all_to_all_combine(expert_outputs)
        
        # 5. Weighted merge
        return self.weighted_sum(combined, expert_weights)

VII. 3D / 4D / 5D Parallelism: Combined Usage

Production training typically uses multiple parallelism strategies simultaneously, known as 3D/4D/5D parallelism:

  • 3D parallelism = TP × PP × DP (Megatron-LM's classic scheme)
  • 4D parallelism = TP × PP × DP × SP (adding sequence parallelism)
  • 5D parallelism = TP × PP × DP × SP × EP (adding expert parallelism, for MoE models)

Configuration example (Llama-3 70B on 512 H100s): TP=8 (intra-node NVLink) × PP=8 (cross-node InfiniBand) × DP=8 (ZeRO-1) = 512 GPUs.

Rules of Thumb for Choosing Parallelism Strategies

Parallelism Strategy Selection Decision Tree

Following the decision tree: determine TP first (can a single layer fit?), then PP (cross-node needed?), then SP (sequence > 32K?), give remaining GPUs to DP, and MoE independently adds EP.

# Actual configuration example (reference Megatron-LM and LLaMA-Factory)
training_config = {
    # Parallelism dimensions
    "tensor_model_parallel_size": 4,      # TP=4 (4 GPUs within a node)
    "pipeline_model_parallel_size": 4,    # PP=4 (across 4 nodes)
    "data_parallel_size": 16,             # DP=16 (total 256 GPUs / TP4 / PP4)
    "sequence_parallel": True,            # Used with TP
    
    # ZeRO configuration
    "zero_stage": 1,                      # Usually only ZeRO-1 needed in PP+TP mode
    
    # Batch configuration
    "global_batch_size": 2048,
    "micro_batch_size": 2,               # Each GPU processes 2 samples at a time
    "gradient_accumulation_steps": 64,   # 2048 / (2 × 16) = 64
    
    # Sequence length
    "seq_length": 8192,
    
    # Mixed precision
    "bf16": True,
    "fp32_residual_connection": False,
}
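As a sanity check, the arithmetic constraints tying these fields together can be verified mechanically (a hypothetical helper that mirrors the dict's key names; not part of any framework):

```python
def check_parallel_config(cfg, total_gpus):
    tp = cfg["tensor_model_parallel_size"]
    pp = cfg["pipeline_model_parallel_size"]
    dp = cfg["data_parallel_size"]
    assert tp * pp * dp == total_gpus, "TP × PP × DP must equal total GPU count"
    assert (cfg["micro_batch_size"] * cfg["gradient_accumulation_steps"] * dp
            == cfg["global_batch_size"]), "global batch = micro × grad_accum × DP"
    return True

cfg = {
    "tensor_model_parallel_size": 4,
    "pipeline_model_parallel_size": 4,
    "data_parallel_size": 16,
    "global_batch_size": 2048,
    "micro_batch_size": 2,
    "gradient_accumulation_steps": 64,
}
assert check_parallel_config(cfg, total_gpus=256)
```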

Communication Volume Comparison

| Strategy | Communication Operation | Volume | Latency Sensitivity | Recommended Network |
|---|---|---|---|---|
| DP | AllReduce | 2P (P = parameter count) | Low | Ethernet / IB |
| TP | AllReduce / AllGather | 2 × activations per layer | High | NVLink (intra-node) |
| PP | P2P Send/Recv | Stage-boundary activations per micro-batch | Medium | IB |
| SP | AllGather / ReduceScatter | Same as TP | High | NVLink (intra-node) |
| ZeRO-3 | AllGather (forward) + ReduceScatter (backward) | 3P (more) | Low | IB |

VIII. Gradient Checkpointing

This is not a parallelism strategy, but it's inseparable from distributed training — it's an important technique for trading computation for memory:

# Standard training: retain all activations (memory O(layers × S × H))
output = model(input)
loss = criterion(output, target)
loss.backward()  # Uses activations saved during forward pass

# Gradient checkpointing: only retain some activations, recompute during backward
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(model, input):
    """
    Don't save intermediate activations; recompute forward once during backward propagation
    
    Memory savings: from O(L × S × H) → O(√L × S × H)
    Compute overhead: ~30% increase (equivalent to one extra forward pass)
    """
    # Enable checkpointing for each Transformer layer
    for layer in model.layers:
        # Don't save layer's activations
        input = checkpoint(layer, input, use_reentrant=False)
    return input

# Enable in Hugging Face Transformers
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
model.gradient_checkpointing_enable()  # One line, saves ~50% memory
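The O(L) → O(√L) claim comes from placing a checkpoint every √L-th layer: you keep √L checkpoints plus one segment's activations while recomputing. A toy accounting (illustrative function; the unit is "one layer's activations"):

```python
import math

def activation_units(num_layers, checkpointing):
    """How many layers' worth of activations must be resident at once."""
    if not checkpointing:
        return num_layers                   # every layer's activations kept
    seg = round(math.sqrt(num_layers))      # checkpoint interval ≈ sqrt(L)
    return num_layers // seg + seg          # stored checkpoints + one live segment

print(activation_units(64, False))  # 64
print(activation_units(64, True))   # 16  (8 checkpoints + 8 recomputed)
```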

Comprehensive Comparison and Selection Recommendations

| Model Scale | GPU Count | Recommended Config | Typical Framework |
|---|---|---|---|
| 1B~7B | 1~8 GPUs | DP + ZeRO-2/3 | DeepSpeed, FSDP |
| 7B~13B | 8~32 GPUs | DP/ZeRO-3 + GradCkpt | FSDP, LLaMA-Factory |
| 13B~70B | 32~256 GPUs | TP=4 + PP=2 + DP + ZeRO-1 | Megatron-LM |
| 70B~400B | 256~1024 GPUs | TP=8 + PP=4~8 + DP + ZeRO-1 | Megatron-LM |
| 400B~1T (MoE) | 512~8192 GPUs | TP=8 + PP=8 + EP=8 + DP + ZeRO-1 | Megatron-Core |

Comparison of Mainstream Training Frameworks

| Framework | Supported Parallelism | Suitable Scenarios | Ease of Use |
|---|---|---|---|
| DeepSpeed | DP+ZeRO, PP, TP (limited) | Small-medium models, resource-constrained | ⭐⭐⭐⭐ |
| PyTorch FSDP | DP+ZeRO-3 | Medium scale, PyTorch native | ⭐⭐⭐⭐ |
| Megatron-LM | TP+PP+SP+DP+ZeRO | Ultra-large-scale pre-training | ⭐⭐ |
| LLaMA-Factory | Wraps FSDP/DeepSpeed | SFT/RL fine-tuning | ⭐⭐⭐⭐⭐ |
| Axolotl | Wraps FSDP/DeepSpeed | SFT fine-tuning | ⭐⭐⭐⭐ |

Section Summary

| Technology | Split Dimension | Problem Solved | Key Constraint |
|---|---|---|---|
| DP / DDP | Batch | Throughput | Each GPU needs the complete model |
| ZeRO-1/2/3 | Optimizer / gradients / parameters | Memory redundancy under DP | Increased communication volume |
| FSDP | Parameters + gradients + optimizer | PyTorch-native ZeRO-3 | Repeated AllGather overhead |
| TP | Inside weight matrices | Single layer too large | Requires NVLink-class bandwidth |
| PP | Model layers (depth) | Too many layers | Pipeline bubbles |
| SP | Sequence length | Long-sequence activations | Used together with TP |
| CP / Ring Attn | Ultra-long sequences | Million-token attention | Attention computation splitting |
| EP | MoE experts | Distributing expert parameters | AllToAll dynamic routing |

💡 Core takeaway for Agent developers:
If you're fine-tuning models with LLaMA-Factory or Axolotl, FSDP (ZeRO-3) + Gradient Checkpointing is the best choice for small teams; it makes training 70B-class models feasible on as few as 8 GPUs.
If you're designing pre-training from scratch, you need to carefully plan the combination of 3D parallelism (TP × PP × DP).


References

[1] SHOEYBI M, et al. Megatron-LM: training multi-billion parameter language models using model parallelism[J]. arXiv:1909.08053, 2019.

[2] RAJBHANDARI S, et al. ZeRO: memory optimizations toward training trillion parameter models[C]//SC, 2020.

[3] KORTHIKANTI V, et al. Reducing activation recomputation in large Transformer models[J]. arXiv:2205.05198, 2022. (SP original paper)

[4] LIU Z, et al. Ring attention with blockwise transformers for near-infinite context[J]. arXiv:2310.01889, 2023. (Ring Attention)

[5] Microsoft DeepSpeed Team. ZeRO++: extremely efficient collective communication for giant model training[J]. arXiv:2306.10209, 2023.

[6] ZHAO Y, et al. PyTorch FSDP: experiences on scaling fully sharded data parallel[J]. arXiv:2304.11277, 2023.


Previous section: 11.2 SFT + LoRA Basic Training
Next section: 11.3 PPO: Proximal Policy Optimization