11.2b 分布式训练基础:DP / TP / PP / SP / ZeRO

🖥️ "训练一个 70B 参数模型,单卡 A100(80GB 显存)根本放不下——即便放得下,等你训到退休也没训完。分布式训练不是锦上添花,而是大模型存在的前提。"


为什么需要分布式训练?

在理解各种并行策略之前,先搞清楚显存墙算力墙这两个核心约束。

显存墙:一张卡装不下

以训练一个 7B 参数的模型为例,混合精度训练(BF16)时,显存占用来自以下几个部分:

7B 模型混合精度训练显存占用估算

如上图所示,优化器状态(Adam 的 m + v 动量)是最大的开销来源,单项就超过 A100 的全部显存。加上参数、梯度和激活值,7B 模型混合精度训练需要 106~138 GB,远超单卡 A100 的 80 GB 上限。

算力墙:一张卡等不起

以 Llama-3 70B 为例,训练 1T tokens 大约需要 次浮点运算:

单张 A100(312 TFLOPS)需要:

用 1000 张 A100 组成集群,仍需 15 天,这才是实际训练的时间规模。

分布式训练通过将计算和存储分散到多个 GPU 上,同时解决这两个问题。


五大并行维度总览

现代 LLM 训练综合使用最多 5 种并行维度,它们针对的是计算图的不同切分轴:

LLM 分布式训练五大并行策略

并行策略英文缩写切分维度解决的核心问题
数据并行Data ParallelismDPBatch 维度训练速度(吞吐量)
张量并行Tensor ParallelismTP权重矩阵维度单层参数过大
流水线并行Pipeline ParallelismPP模型层(深度)维度层数过多
序列并行Sequence ParallelismSP序列长度维度长序列激活值过大
专家并行Expert ParallelismEPMoE 专家数维度MoE 模型专家过多

ZeRO 是一种优化器状态分片技术,配合 DP 使用,严格说不属于新的并行维度,但对显存优化至关重要。


一、数据并行(DP / DDP)

核心思路

数据并行是最简单、最常用的并行策略:每个 GPU 持有完整的模型副本,但处理不同的数据子集,最后聚合梯度更新参数。

数据并行 DP 原理 + Ring-AllReduce 通信

如上图左侧所示,每个 GPU 拥有完整模型副本,分别对不同的 Batch 做前向 + 反向计算,然后通过 AllReduce 聚合平均梯度,同步更新所有副本的参数。右侧是 Ring-AllReduce 的通信方式——各 GPU 排成一个环,梯度片段沿环传递,每张卡的通信量与 GPU 数量无关,实现近似线性扩展。

DDP vs DP

PyTorch 提供两种数据并行实现:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DataParallel as DP

# ─── 方式一:DataParallel(DP)——单机多卡,简单但有瓶颈 ───
# 问题:主卡(GPU 0)负责汇聚梯度,成为通信瓶颈
# 问题:各 GPU 显存不均衡(主卡负担更重)
model = MyModel().cuda()
model = DataParallel(model, device_ids=[0, 1, 2, 3])
output = model(input)   # 自动分发 batch 到各 GPU

# ─── 方式二:DistributedDataParallel(DDP)——推荐!───
# 优点:每个 GPU 独立计算梯度,Ring-AllReduce 均匀通信
# 优点:支持多机多卡,线性扩展性更好
import torch.distributed as dist

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def train_ddp(rank, world_size, model, dataset):
    setup(rank, world_size)
    
    # 每个进程持有一份模型
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    
    # DistributedSampler 保证每个进程看到不同数据
    sampler = torch.utils.data.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()          # DDP 自动在 backward 期间触发 AllReduce
        optimizer.step()

AllReduce:梯度聚合的通信算法

DDP 的核心操作是 Ring-AllReduce(已在上图右侧展示):每个 GPU 把自己的梯度片段沿环形逐步传递,经过 2(N-1) 轮通信后,每个 GPU 都拥有完整的平均梯度。关键性质是通信量与 GPU 数量无关——即使 1000 张卡,每张卡的通信量仍约等于 2 倍参数量,实现理想线性扩展。

DP 的局限性

显存问题:每个 GPU 都要存储完整模型参数 + 梯度 + 优化器状态。对 70B 模型,即使 DP=1000,每张卡仍需 100GB+。这催生了 ZeRO(见后文)。

有效 batch sizeglobal_batch_size = local_batch_size × world_size。DP=1000 时,有效 batch 可能过大导致训练不稳定,需要配合 Gradient Accumulation

# Gradient Accumulation:模拟大 batch 而不增加实际 batch size
accumulation_steps = 8  # 累积 8 个 mini-batch 才更新一次

for step, batch in enumerate(loader):
    loss = model(batch) / accumulation_steps  # 缩放损失
    loss.backward()                            # 累积梯度
    
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                       # 每 8 步才更新
        optimizer.zero_grad()

二、张量并行(TP)

核心思路

张量并行将单个权重矩阵在 GPU 之间切分,每个 GPU 只持有矩阵的一部分,并行计算矩阵乘法。

这是 Megatron-LM [1] 提出的核心技术,专门针对 Transformer 的两大密集层:

张量并行 TP 列并行/行并行矩阵切分与 MLP 数据流

如上图所示,TP 有两种切分方式:

  • 列并行(Column Parallel):将权重 W 按列切分,各 GPU 独立计算局部输出,最后 Concat 拼接——前向无需通信
  • 行并行(Row Parallel):将权重按行切分,同时切分输入,各 GPU 独立计算后 AllReduce 求和——前向需要 1 次 AllReduce

实际 Transformer 的 MLP 层由"列并行 + 激活函数 + 行并行"组成(见图底部数据流),每层共需 2 次 AllReduce(前向 1 次 + 反向 1 次)。

线性层的列并行(Column Parallel Linear)概念:

  • 权重 W [H, 4H] 按列切分为 W₀ [H, 2H] 和 W₁ [H, 2H]
  • GPU 0 计算 Y₀ = X × W₀,GPU 1 计算 Y₁ = X × W₁
  • 输出 Concat:Y = [Y₀ | Y₁] [B, s, 4H]

线性层的行并行(Row Parallel Linear)概念:

  • 权重 W [4H, H] 按行切分,输入也相应切分
  • GPU 0:X₀ × W₀ = Y₀,GPU 1:X₁ × W₁ = Y₁
  • AllReduce:Y = Y₀ + Y₁ [B, s, H]

MLP 层的 TP 分解

# 标准 FFN 层的 TP 分解(Megatron-LM 风格)
class ColumnParallelLinear(nn.Module):
    """
    权重按列切分:
    W_full [H, 4H] → 每个 GPU 持有 W_local [H, 4H/tp_size]
    """
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        # 每个 GPU 只持有 1/tp_size 的列
        self.weight = nn.Parameter(
            torch.randn(in_features, out_features // tp_size)
        )
    
    def forward(self, x):
        # 局部矩阵乘法,不需要通信
        return F.linear(x, self.weight)  # [B, s, out/tp]


class RowParallelLinear(nn.Module):
    """
    权重按行切分:
    W_full [4H, H] → 每个 GPU 持有 W_local [4H/tp_size, H]
    同时输入 x 也已经是局部的(来自 ColumnParallel 的输出)
    """
    def __init__(self, in_features, out_features, tp_size):
        super().__init__()
        self.tp_size = tp_size
        self.weight = nn.Parameter(
            torch.randn(in_features // tp_size, out_features)
        )
    
    def forward(self, x):
        # 局部矩阵乘法
        local_output = F.linear(x, self.weight)  # [B, s, H]
        # AllReduce 聚合所有 GPU 的部分结果
        dist.all_reduce(local_output, op=dist.ReduceOp.SUM)
        return local_output


class TensorParallelMLP(nn.Module):
    """
    完整 MLP 的 TP 实现:
    FFN(x) = GeLU(x @ W_up) @ W_down
    
    通信模式:
    输入 x → [AllGather] → 列并行W_up → 行并行W_down → [AllReduce] → 输出
    总计:前向 1 次 AllGather + 1 次 AllReduce
         反向 1 次 AllGather + 1 次 AllReduce
    """
    def __init__(self, hidden_size, ffn_size, tp_size):
        super().__init__()
        self.up_proj = ColumnParallelLinear(hidden_size, ffn_size, tp_size)
        self.down_proj = RowParallelLinear(ffn_size, hidden_size, tp_size)
    
    def forward(self, x):
        return self.down_proj(F.gelu(self.up_proj(x)))

注意力层的 TP

多头注意力(MHA)的 TP 更自然——直接按注意力头切分:

# Attention TP:每个 GPU 负责部分注意力头
class TensorParallelAttention(nn.Module):
    """
    假设 n_heads=32,TP=4:
    每个 GPU 负责 8 个头
    
    Q、K、V 投影:列并行(输出切分)
    Output 投影:行并行(输入切分 + AllReduce)
    """
    def __init__(self, d_model, n_heads, tp_size):
        super().__init__()
        self.local_heads = n_heads // tp_size       # 每 GPU 负责的头数
        self.head_dim = d_model // n_heads
        local_d = self.local_heads * self.head_dim  # 本 GPU 的 KQV 维度
        
        # 列并行:每个 GPU 只持有部分 Q/K/V 投影
        self.qkv_proj = ColumnParallelLinear(d_model, 3 * local_d, tp_size=1)
        # 行并行:聚合各 GPU 的注意力输出
        self.out_proj = RowParallelLinear(local_d, d_model, tp_size)

TP 的适用场景与限制

特性说明
通信量每层 2 次 AllReduce(前向 1 次 + 反向 1 次)
通信延迟高——每层必须等通信完成才能继续(同步通信)
推荐 GPU 连接必须在同一节点内(NVLink),跨节点带宽太低
适合单层参数量过大(MLP 层 4H×H 权重)
不适合层数多但每层不大;跨节点扩展
典型规模TP=4~8(一台 8 卡服务器内使用)

三、流水线并行(PP)

核心思路

流水线并行将模型按层切分,不同 GPU 负责不同的层组,像工厂流水线一样并行处理:

流水线并行 PP:顺序执行 vs 1F1B 调度甘特图

如上图所示,无 PP 时 GPU 大部分时间空闲等待;1F1B 调度则让每个 GPU 在稳定阶段交替执行前向和反向,显著提升利用率。气泡率公式为:

其中 P 是 PP_stages(流水线阶段数),M 是 micro_batches 数量。当 M ≫ P 时,气泡率趋近于零

Bubble(流水线气泡)问题

流水线并行最大的挑战是气泡(Bubble)——GPU 等待前一个阶段输出时的空闲时间:

# GPipe 调度策略(朴素 PP)
class GPipe:
    """
    每个 micro-batch 完整前向传播后再反向传播
    气泡率 = (PP_stages - 1) / (micro_batches + PP_stages - 1)
    
    当 micro_batches >> PP_stages 时,气泡率 → 0
    但 micro_batches 太多会增加显存(需要缓存所有激活值)
    """
    def __init__(self, stages=4, micro_batches=8):
        self.stages = stages
        self.micro_batches = micro_batches
        self.bubble_rate = (stages - 1) / (micro_batches + stages - 1)
        print(f"气泡率: {self.bubble_rate:.1%}")  # stages=4, m=8 → 27.3%

1F1B 调度(Megatron-LM 提出,更优):

1F1B 的核心改进在于显存峰值更低——它只需缓存 stages 个 micro-batch 的激活值(而 GPipe 需要缓存所有 M 个),因此在相同气泡率下更节省显存。详见上方甘特图中的调度对比。

class OneFOneB:
    """
    1F1B 调度核心思路:
    稳定状态下,每个时间步都是 1 次前向 + 1 次反向
    避免了 GPipe 需要缓存所有 micro-batch 激活值的问题
    """
    def schedule(self, num_stages, num_micro_batches):
        """
        返回每个 stage 的执行序列
        F = Forward, B = Backward
        数字表示 micro-batch 编号
        """
        schedule = {stage: [] for stage in range(num_stages)}
        
        # Warmup phase: 前 PP_stages 个 micro-batch 只做前向
        for stage in range(num_stages):
            for mb in range(num_stages - stage):
                schedule[stage].append(f"F{mb}")
        
        # Steady state: 每个 GPU 都是 1F1B 交替
        # ...(实际实现见 Megatron-LM 源码)
        
        return schedule

PP 的适用场景与限制

特性说明
通信量只在相邻阶段间传递激活值(点对点,低通信量)
通信延迟相对低(P2P 通信)
推荐 GPU 连接适合跨节点(100Gbps IB 即可)
适合模型层数多,单层参数量不大
缺点流水线气泡;调试复杂;实现难度高
典型规模PP=4~16(多机多卡)

四、序列并行(SP)

核心思路

序列并行(Sequence Parallelism)沿序列长度维度切分,主要解决长序列下激活值显存爆炸的问题。

在标准 Transformer 中,激活值的显存占用与序列长度成二次方关系(注意力矩阵 ),当序列长度从 2K 增长到 128K 时,这是致命瓶颈。

序列并行 SP+TP 通信数据流 + Ring Attention(CP)

如上图所示,SP+TP 将标准 TP 的 AllReduce 拆分为 AllGather + ReduceScatter,使序列在 SP 区域(LayerNorm、Dropout 等)始终保持分片状态,从而将这些位置的激活值显存减少 tp_size 倍。图中下方的 Ring Attention(Context Parallelism)则可进一步将注意力矩阵的显存从 降至 ,支持超过 1M token 的超长上下文训练。

SP + TP 组合(Megatron-LM v3)

SP 通常与 TP 配合使用(在同一组 GPU 内),共同减少激活值(已在上图中展示完整数据流)。

关键优化:将 TP 的 AllReduce 替换为 ReduceScatter + AllGather 组合,可以在序列维度保持分片,减少 50% 的激活值显存。

# SP+TP 的通信模式(比较)
class SequenceParallelism:
    """
    标准 TP 通信:
      前向:AllGather(x) → ColParallel → RowParallel + AllReduce
      反向:AllReduce(∇) → RowParallel → ColParallel + AllGather
    
    SP+TP 通信(激活值始终保持序列分片):
      前向:AllGather(x) → ColParallel → RowParallel + ReduceScatter
      反向:AllGather(∇) → RowParallel → ColParallel + ReduceScatter
    
    显存节省:序列并行区域(Dropout, LayerNorm)的激活值减少 tp_size 倍
    通信量:与标准 TP 相同(AllGather ≈ AllReduce 通信量)
    """
    pass

Context Parallelism(CP):Ring Attention

当序列长度超过 SP 能处理的范围时(例如 1M token),还有更激进的 Context Parallelism

# Ring Attention(CP 的核心)
# 将注意力的 Q/K/V 沿序列维度切分,
# 通过 Ring 通信实现分布式注意力计算

class RingAttention:
    """
    核心思路:
    - 每个 GPU 拥有完整 Q 的一段 [B, S/cp, H]
    - K/V 以 Ring 方式在 GPU 间循环传递
    - 每个 GPU 在本地完成部分注意力计算
    - 最终合并得到完整注意力输出
    
    通信量:O(S/cp × H × cp) = O(S × H),与层数无关
    显存:注意力矩阵从 O(S²) 降至 O(S²/cp²)
    
    典型应用:Apple MLX、Google JAX 超长上下文训练
    """
    def forward(self, q, k, v, cp_group):
        S_local = q.shape[1]  # S / cp
        output = torch.zeros_like(q)
        
        # 本地 K/V
        k_local = k.clone()
        v_local = v.clone()
        
        for step in range(self.cp_size):
            # 计算当前 K/V 块的注意力
            attn_out = flash_attn(q, k_local, v_local, causal=(step == 0))
            output += attn_out
            
            # 通过 Ring 传递 K/V 到下一个 GPU
            k_local = self.ring_send_recv(k_local, cp_group)
            v_local = self.ring_send_recv(v_local, cp_group)
        
        return output

五、ZeRO:消除优化器冗余

问题来源

在标准 DDP 中,每个 GPU 存储完整的:

  • 模型参数(Parameters):FP16,2 byte/param
  • 梯度(Gradients):FP32,4 byte/param
  • 优化器状态(Optimizer states):Adam 需要 m + v,FP32,8 byte/param

总计约 16 byte/param。对 70B 模型 = 1120 GB,即使 1000 张 A100 仍需每卡 > 1GB,但实际问题是每张卡存储了完全相同的优化器状态——这是巨大的冗余!

ZeRO 三阶段

ZeRO(Zero Redundancy Optimizer) 由微软 DeepSpeed 提出 [2],通过分片消除冗余:

ZeRO 0/1/2/3 三阶段分片对比(以 4 GPU、16 byte/param 为例)

如上图所示,ZeRO 分三个阶段递进分片,逐步消除优化器状态、梯度、模型参数的冗余:

阶段分片内容每卡显存(N=4)节省
ZeRO-0(DDP)无分片80 B/param
ZeRO-1优化器状态~38 B/param52%
ZeRO-2+ 梯度~21 B/param74%
ZeRO-3+ 模型参数80/N B/paramN 倍

1000 张 A100 训练 70B 模型,每卡约 1.12GB,轻松放下!

# 使用 DeepSpeed ZeRO 训练
import deepspeed

# ZeRO Stage 配置
ds_config = {
    "zero_optimization": {
        "stage": 3,                      # ZeRO-3:全量分片
        "offload_optimizer": {
            "device": "cpu",             # 可选:将优化器状态卸载到 CPU
            "pin_memory": True,
        },
        "offload_param": {
            "device": "cpu",             # 可选:将参数卸载到 CPU(ZeRO-Infinity)
        },
        "overlap_comm": True,            # 通信与计算重叠
        "contiguous_gradients": True,    # 连续显存提升通信效率
        "sub_group_size": 1e9,
        "reduce_bucket_size": 5e8,
    },
    "bf16": {"enabled": True},
    "gradient_checkpointing": True,
}

# 初始化 DeepSpeed 引擎
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=ds_config,
)

# 训练循环与普通 PyTorch 几乎相同
for batch in dataloader:
    loss = model_engine(batch)
    model_engine.backward(loss)    # 替代 loss.backward()
    model_engine.step()            # 替代 optimizer.step()

ZeRO++ 与 ZeRO-Infinity

ZeRO++(2023) 在 ZeRO-3 基础上进一步压缩通信量:

  • qwZ(量化权重):AllGather 时将 FP16 量化为 INT8,通信量减少 50%
  • hpZ(层次化分区):优先做节点内分区,减少跨节点流量
  • qgZ(量化梯度):ReduceScatter 前量化,进一步降低带宽需求

ZeRO-Infinity 则将参数/梯度/优化器状态卸载到 CPU RAM 或 NVMe SSD,理论上可训练任意大小的模型,但速度受限于 PCIe 带宽,适合"大模型 + 少量 GPU"的研究场景。

FSDP(PyTorch 原生 ZeRO)

PyTorch 2.0+ 内置了 Fully Sharded Data Parallel(FSDP),是 ZeRO-3 的原生实现:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
)

# FSDP 配置
fsdp_config = dict(
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # ZeRO-3 等效
    # ShardingStrategy.SHARD_GRAD_OP = ZeRO-2
    # ShardingStrategy.NO_SHARD = 标准 DDP
    
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # 预取下一层参数
    cpu_offload=None,  # 或 CPUOffload(offload_params=True)
    auto_wrap_policy=lambda module, recurse, *args: (
        recurse or isinstance(module, TransformerDecoderLayer)
    ),
)

model = FSDP(model, **fsdp_config)

六、专家并行(EP)

EP 专门针对 MoE(Mixture of Experts) 模型,如 Mixtral、DeepSeek-V3 等。

MoE 回顾

专家并行 EP:MoE 路由与 AllToAll 动态分发

如上图左侧所示,标准 FFN 对每个 token 使用同一个 FFN;而 MoE FFN 设置了多个专家,每个 token 只经过其中 K 个(Top-K 路由),大幅提升了参数量而不增加激活计算量。DeepSeek-V3 使用 256 个专家,每 token 激活 8 个,总参数 671B 但激活约 37B。

EP 的切分方式

如上图右侧所示,EP 将不同专家分配给不同 GPU(以 256 专家、EP=8 为例,每 GPU 持有 32 个专家)。最大挑战是 MoE 路由是动态的——每个 token 发给哪个专家在运行时才决定,需要两次 AllToAll 通信(分发 token → 计算 → 收回结果)。

class ExpertParallelMoE(nn.Module):
    """
    专家并行 MoE 层
    
    通信模式:
    1. Router 决定每个 token 发给哪个专家(本地计算)
    2. AllToAll:将 token 发送到对应 GPU
    3. 各 GPU 独立计算自己持有专家的 FFN
    4. AllToAll:将计算结果发回原始 GPU
    5. 合并专家输出
    """
    def __init__(self, d_model, n_experts, n_experts_per_token, ep_group):
        super().__init__()
        self.ep_group = ep_group
        self.ep_size = dist.get_world_size(ep_group)
        self.local_n_experts = n_experts // self.ep_size
        
        # 每个 GPU 只持有 local_n_experts 个专家
        self.experts = nn.ModuleList([
            MLP(d_model) for _ in range(self.local_n_experts)
        ])
        self.router = Router(d_model, n_experts, n_experts_per_token)
    
    def forward(self, x):
        B, S, D = x.shape
        x_flat = x.view(-1, D)  # [B*S, D]
        
        # 1. 路由计算(本地)
        expert_indices, expert_weights = self.router(x_flat)
        
        # 2. AllToAll:将 token 分发到对应 GPU
        x_dispatched = self.all_to_all_dispatch(x_flat, expert_indices)
        
        # 3. 本地专家计算
        expert_outputs = []
        for i, expert in enumerate(self.experts):
            # 获取分配给本 GPU 第 i 个专家的 token
            tokens_for_expert = x_dispatched[i]
            if tokens_for_expert.shape[0] > 0:
                expert_outputs.append(expert(tokens_for_expert))
        
        # 4. AllToAll:将结果发回原始 GPU
        combined = self.all_to_all_combine(expert_outputs)
        
        # 5. 加权合并
        return self.weighted_sum(combined, expert_weights)

七、3D / 4D / 5D 并行:组合使用

生产训练通常同时使用多种并行策略,这被称为 3D/4D/5D 并行

  • 3D 并行 = TP × PP × DP(Megatron-LM 的经典方案)
  • 4D 并行 = TP × PP × DP × SP(加入序列并行)
  • 5D 并行 = TP × PP × DP × SP × EP(再加入专家并行,MoE 模型专用)

配置示例(Llama-3 70B 在 512 张 H100 上):TP=8(节点内 NVLink)× PP=8(跨节点 InfiniBand)× DP=8(ZeRO-1)= 512 GPU。

选择并行策略的经验法则

并行策略选型决策树

按照决策树:先确定 TP(单层是否放得下),再确定 PP(是否需要跨节点),再看 SP(序列是否超 32K),剩余 GPU 全给 DP,MoE 则独立叠加 EP。

# 实际配置示例(参考 Megatron-LM 和 LLaMA-Factory)
training_config = {
    # 并行维度
    "tensor_model_parallel_size": 4,      # TP=4(节点内 4 卡)
    "pipeline_model_parallel_size": 4,    # PP=4(跨 4 个节点)
    "data_parallel_size": 16,             # DP=16(总 256 卡 / TP4 / PP4)
    "sequence_parallel": True,            # 与 TP 配合使用
    
    # ZeRO 配置
    "zero_stage": 1,                      # PP+TP 模式下通常只需 ZeRO-1
    
    # Batch 配置
    "global_batch_size": 2048,
    "micro_batch_size": 2,               # 每张卡每次处理 2 个样本
    "gradient_accumulation_steps": 64,   # 2048 / (2 × 16) = 64
    
    # 序列长度
    "seq_length": 8192,
    
    # 混合精度
    "bf16": True,
    "fp32_residual_connection": False,
}

通信量对比

策略通信操作通信量延迟敏感性推荐网络
DPAllReduce2P(参数量)Ethernet / IB
TPAllReduce / AllGather2 × 激活量/层NVLink(节点内)
PPP2P Send/Recv激活量 × B/SIB
SPAllGather / ReduceScatter同 TPNVLink(节点内)
ZeRO-3AllGather(前向)+ ReduceScatter(反向)3P(更多)IB

八、Gradient Checkpointing(梯度检查点)

这不是并行策略,但与分布式训练密不可分——它是用计算换显存的重要技巧:

# 标准训练:保留所有激活值(显存 O(layers × S × H))
output = model(input)
loss = criterion(output, target)
loss.backward()  # 使用前向时保存的激活值

# 梯度检查点:只保留部分激活值,反向时重新计算
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(model, input):
    """
    不保存中间激活值,反向传播时重新前向计算一次
    
    显存节省:从 O(L × S × H) → O(√L × S × H)
    计算开销:增加约 30%(相当于额外做了一次前向)
    """
    # 对每个 Transformer 层启用检查点
    for layer in model.layers:
        # 不保存 layer 的激活值
        input = checkpoint(layer, input, use_reentrant=False)
    return input

# 在 Hugging Face Transformers 中启用
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3-8B")
model.gradient_checkpointing_enable()  # 一行代码,节省 ~50% 显存

综合对比与选型建议

不同模型规模的推荐配置

模型规模GPU 数量推荐配置典型框架
1B~7B1~8 卡DP + ZeRO-2/3DeepSpeed, FSDP
7B~13B8~32 卡DP/ZeRO-3 + GradCkptFSDP, LLaMA-Factory
13B~70B32~256 卡TP=4 + PP=2 + DP + ZeRO-1Megatron-LM
70B~400B256~1024 卡TP=8 + PP=4~8 + DP + ZeRO-1Megatron-LM
400B~1T(MoE)512~8192 卡TP=8 + PP=8 + EP=8 + DP + ZeRO-1Megatron-Core

主流训练框架对比

框架支持的并行适合场景易用性
DeepSpeedDP+ZeRO, PP, TP(有限)中小模型,资源受限⭐⭐⭐⭐
PyTorch FSDPDP+ZeRO-3中等规模,PyTorch 原生⭐⭐⭐⭐
Megatron-LMTP+PP+SP+DP+ZeRO超大规模预训练⭐⭐
LLaMA-Factory封装 FSDP/DeepSpeedSFT/RL 微调⭐⭐⭐⭐⭐
Axolotl封装 FSDP/DeepSpeedSFT 微调⭐⭐⭐⭐

本节小结

技术切分维度解决问题关键约束
DP / DDPBatch吞吐量每卡需存完整模型
ZeRO-1/2/3优化器/梯度/参数DP 下的显存冗余通信量增加
FSDP参数+梯度+优化器ZeRO-3 的 PyTorch 原生版多 AllGather 开销
TP权重矩阵内部单层过大需 NVLink 高带宽
PP模型层(深度)层数过多流水线气泡
SP序列长度长序列激活值配合 TP 使用
CP / Ring Attn超长序列百万 token 注意力注意力计算切分
EPMoE 专家专家参数分布AllToAll 动态路由

💡 Agent 开发者的核心收获
如果你在用 LLaMA-Factory 或 Axolotl 微调模型,FSDP(ZeRO-3)+ Gradient Checkpointing 是小团队的最优选择——支持 8 卡以内训练 70B 模型。
如果你在设计从零预训练,需要认真规划 3D 并行(TP × PP × DP)的组合。


参考文献

[1] SHOEYBI M, et al. Megatron-LM: training multi-billion parameter language models using model parallelism[J]. arXiv:1909.08053, 2019.

[2] RAJBHANDARI S, et al. ZeRO: memory optimizations toward training trillion parameter models[C]//SC, 2020.

[3] KORTHIKANTI V, et al. Reducing activation recomputation in large Transformer models[J]. arXiv:2205.05198, 2022.(SP 原论文)

[4] LIU Z, et al. Ring attention with blockwise transformers for near-infinite context[J]. arXiv:2310.01889, 2023.(Ring Attention)

[5] Microsoft DeepSpeed Team. ZeRO++: extremely efficient collective communication for giant model training[J]. arXiv:2306.10209, 2023.

[6] ZHAO Y, et al. PyTorch FSDP: experiences on scaling fully sharded data parallel[J]. arXiv:2304.11277, 2023.


上一节:11.2 SFT + LoRA 基础训练
下一节:11.3 PPO:近端策略优化