分布式训练与显存优化
「单卡放不下千亿模型怎么办」是工程岗与算法岗都会问的硬核题。本文讲清几种并行策略、ZeRO、混合精度,以及显存到底花在哪。
显存都花在哪了?
训练时单卡显存主要由四部分组成:
| 部分 | 说明 | 量级(以参数量 Ψ 计) |
|---|---|---|
| 模型参数(Weights) | FP16 下每参数 2 字节 | 2Ψ |
| 梯度(Gradients) | 与参数同形状 | 2Ψ |
| 优化器状态(Optimizer States) | Adam 要存动量 + 方差,且常用 FP32 主副本 | 12Ψ |
| 激活值(Activations) | 前向中间结果,供反向用 | 随 batch、序列长度变化 |
关键认知:用 Adam 训练时,优化器状态(12Ψ)+ 参数和梯度的 FP32 副本,往往比模型本身大得多。一个 7B 模型,光「参数+梯度+优化器状态」就要约 7B×16 ≈ 112GB,远超单卡显存。这正是 ZeRO 要解决的问题。
激活值可用 梯度检查点(Gradient/Activation Checkpointing) 大幅降低:只保存部分激活,反向时重算其余,用计算换显存。
三种并行策略
数据并行(Data Parallelism, DP)
每张卡保存完整模型副本,喂不同的数据切片,各自前向反向算出梯度后通过 All-Reduce 同步求平均,再各自更新。
- 优点:实现简单、扩展性好、加速明显。
- 缺点:每张卡都要放下完整模型,解决不了「单卡放不下」的问题。
张量并行(Tensor Parallelism, TP,层内并行)
把单层内部的大矩阵切分到多张卡上并行计算(如把 FFN/Attention 的权重矩阵按列/行切开),算完再通信合并。
- 适合切分单层巨大的权重;通信频繁(每层都要通信),对带宽要求极高,通常只在单机多卡(NVLink)内做。
- 代表实现:Megatron-LM。
流水线并行(Pipeline Parallelism, PP,层间并行)
把模型的不同层放到不同卡上,数据像流水线一样逐段流过。
- 挑战是「流水线气泡(bubble)」——前面的卡算完要等后面,存在空闲。用 micro-batch(把 batch 切小、错峰送入,如 GPipe / 1F1B 调度)来填充气泡、提高利用率。
3D 并行
超大模型训练通常组合使用 DP + TP + PP(再加 ZeRO/序列并行),即「3D 并行」:机内用 TP(高带宽),机间用 PP 和 DP。这是 GPT-3、Megatron-Turing 等千亿模型的标准做法。
ZeRO(DeepSpeed)与 FSDP
ZeRO(Zero Redundancy Optimizer)的洞察是:数据并行下每张卡都存了完整的参数/梯度/优化器状态,存在大量冗余。 ZeRO 把这些**分片(shard)**到各卡,需要时再用通信临时聚合。
| 阶段 | 分片内容 | 显存节省 | 代价 |
|---|---|---|---|
| ZeRO-1 | 优化器状态 | 4 倍 | 通信略增 |
| ZeRO-2 | + 梯度 | 8 倍 | 通信再增 |
| ZeRO-3 | + 参数 | 与卡数成正比 | 通信最多(前向/反向都要聚合参数) |
- ZeRO-Offload / Infinity:进一步把状态卸载到 CPU 内存甚至 NVMe,单卡也能训大模型(牺牲速度)。
- FSDP(PyTorch Fully Sharded Data Parallel):PyTorch 原生的 ZeRO-3 等价实现,是目前社区主流。
ZeRO/FSDP 本质是「数据并行的省显存版」:保留 DP 的易用性,又通过分片消除冗余,让大模型能在更多卡上铺开。
混合精度训练
用低精度(FP16/BF16)做计算以省显存、提速,同时保留 FP32 主权重保证数值稳定。
- FP16:动态范围小,容易溢出/下溢,需要 Loss Scaling(放大 loss 防止梯度下溢)。
- BF16:指数位与 FP32 相同,动态范围大、不易溢出,无需 loss scaling,是大模型训练首选(需硬件支持,如 A100/H100)。
- FP8:H100 起支持,训练/推理进一步提速,是前沿方向(DeepSeek-V3 用 FP8 训练)。
高频追问
Q:DP、TP、PP 怎么选? 优先 DP(简单);模型单卡放不下时,机内(NVLink)用 TP 切大矩阵、机间用 PP 切层;再叠加 ZeRO/FSDP 省显存。通信开销 TP > PP > DP,所以 TP 只在高带宽的机内用。
Q:ZeRO-3 和张量并行都能省参数显存,区别是什么? TP 是把单层算子切开并行计算(改变计算方式);ZeRO-3 是把参数存储分片,计算时临时聚合(不改变计算逻辑,是数据并行的扩展)。TP 通信在算子内部、更频繁;ZeRO-3 在每层前向/反向边界聚合。二者可叠加。
Q:梯度检查点(重计算)省的是什么显存? 省的是激活值显存。代价是反向传播时要重新做一次前向计算,通常增加约 30% 计算时间,换取激活显存大幅下降,是长序列/大 batch 训练的常用手段。
Q:为什么 BF16 比 FP16 更适合大模型训练? BF16 的指数位和 FP32 一样多(8 位),动态范围大,几乎不会溢出,省去了 FP16 必须的 loss scaling 调参;代价是尾数精度低一些,但对大模型训练影响很小。
Q:流水线气泡怎么来的、怎么减小? 流水线启动和收尾阶段,部分卡无事可做(在等数据流过)。把全局 batch 切成多个 micro-batch 错峰流入(如 1F1B 调度),让各卡尽量同时有活干,就能减小气泡比例。