Skip to content

分布式训练与显存优化

「单卡放不下千亿模型怎么办」是工程岗与算法岗都会问的硬核题。本文讲清几种并行策略、ZeRO、混合精度,以及显存到底花在哪。

显存都花在哪了?

训练时单卡显存主要由四部分组成:

部分说明量级(以参数量 Ψ 计)
模型参数(Weights)FP16 下每参数 2 字节
梯度(Gradients)与参数同形状
优化器状态(Optimizer States)Adam 要存动量 + 方差,且常用 FP32 主副本12Ψ
激活值(Activations)前向中间结果,供反向用随 batch、序列长度变化

关键认知:用 Adam 训练时,优化器状态(12Ψ)+ 参数和梯度的 FP32 副本,往往比模型本身大得多。一个 7B 模型,光「参数+梯度+优化器状态」就要约 7B×16 ≈ 112GB,远超单卡显存。这正是 ZeRO 要解决的问题。

激活值可用 梯度检查点(Gradient/Activation Checkpointing) 大幅降低:只保存部分激活,反向时重算其余,用计算换显存。

三种并行策略

数据并行(Data Parallelism, DP)

每张卡保存完整模型副本,喂不同的数据切片,各自前向反向算出梯度后通过 All-Reduce 同步求平均,再各自更新。

  • 优点:实现简单、扩展性好、加速明显。
  • 缺点:每张卡都要放下完整模型,解决不了「单卡放不下」的问题

张量并行(Tensor Parallelism, TP,层内并行)

单层内部的大矩阵切分到多张卡上并行计算(如把 FFN/Attention 的权重矩阵按列/行切开),算完再通信合并。

  • 适合切分单层巨大的权重;通信频繁(每层都要通信),对带宽要求极高,通常只在单机多卡(NVLink)内做。
  • 代表实现:Megatron-LM。

流水线并行(Pipeline Parallelism, PP,层间并行)

把模型的不同层放到不同卡上,数据像流水线一样逐段流过。

  • 挑战是「流水线气泡(bubble)」——前面的卡算完要等后面,存在空闲。用 micro-batch(把 batch 切小、错峰送入,如 GPipe / 1F1B 调度)来填充气泡、提高利用率。

3D 并行

超大模型训练通常组合使用 DP + TP + PP(再加 ZeRO/序列并行),即「3D 并行」:机内用 TP(高带宽),机间用 PP 和 DP。这是 GPT-3、Megatron-Turing 等千亿模型的标准做法。

ZeRO(DeepSpeed)与 FSDP

ZeRO(Zero Redundancy Optimizer)的洞察是:数据并行下每张卡都存了完整的参数/梯度/优化器状态,存在大量冗余。 ZeRO 把这些**分片(shard)**到各卡,需要时再用通信临时聚合。

阶段分片内容显存节省代价
ZeRO-1优化器状态4 倍通信略增
ZeRO-2+ 梯度8 倍通信再增
ZeRO-3+ 参数与卡数成正比通信最多(前向/反向都要聚合参数)
  • ZeRO-Offload / Infinity:进一步把状态卸载到 CPU 内存甚至 NVMe,单卡也能训大模型(牺牲速度)。
  • FSDP(PyTorch Fully Sharded Data Parallel):PyTorch 原生的 ZeRO-3 等价实现,是目前社区主流。

ZeRO/FSDP 本质是「数据并行的省显存版」:保留 DP 的易用性,又通过分片消除冗余,让大模型能在更多卡上铺开。

混合精度训练

用低精度(FP16/BF16)做计算以省显存、提速,同时保留 FP32 主权重保证数值稳定。

  • FP16:动态范围小,容易溢出/下溢,需要 Loss Scaling(放大 loss 防止梯度下溢)。
  • BF16:指数位与 FP32 相同,动态范围大、不易溢出,无需 loss scaling,是大模型训练首选(需硬件支持,如 A100/H100)。
  • FP8:H100 起支持,训练/推理进一步提速,是前沿方向(DeepSeek-V3 用 FP8 训练)。

高频追问

Q:DP、TP、PP 怎么选? 优先 DP(简单);模型单卡放不下时,机内(NVLink)用 TP 切大矩阵、机间用 PP 切层;再叠加 ZeRO/FSDP 省显存。通信开销 TP > PP > DP,所以 TP 只在高带宽的机内用。

Q:ZeRO-3 和张量并行都能省参数显存,区别是什么? TP 是把单层算子切开并行计算(改变计算方式);ZeRO-3 是把参数存储分片,计算时临时聚合(不改变计算逻辑,是数据并行的扩展)。TP 通信在算子内部、更频繁;ZeRO-3 在每层前向/反向边界聚合。二者可叠加。

Q:梯度检查点(重计算)省的是什么显存? 省的是激活值显存。代价是反向传播时要重新做一次前向计算,通常增加约 30% 计算时间,换取激活显存大幅下降,是长序列/大 batch 训练的常用手段。

Q:为什么 BF16 比 FP16 更适合大模型训练? BF16 的指数位和 FP32 一样多(8 位),动态范围大,几乎不会溢出,省去了 FP16 必须的 loss scaling 调参;代价是尾数精度低一些,但对大模型训练影响很小。

Q:流水线气泡怎么来的、怎么减小? 流水线启动和收尾阶段,部分卡无事可做(在等数据流过)。把全局 batch 切成多个 micro-batch 错峰流入(如 1F1B 调度),让各卡尽量同时有活干,就能减小气泡比例。

基于 MIT 许可发布