训练深入:优化器、混合精度与稳定性
「Adam 和 AdamW 有什么区别?」「为什么要做学习率 warmup?」「训练 loss 突然飙升怎么办?」这些是算法岗深挖训练细节的高频题。本文补齐大模型训练的工程内功。分布式并行见 分布式训练。
优化器:从 SGD 到 AdamW
优化器决定「拿到梯度后怎么更新参数」。
- SGD:朴素地沿梯度反方向走,简单但收敛慢、易震荡。
- Momentum(动量):累积历史梯度方向(像惯性),加速收敛、平滑震荡。
- AdaGrad / RMSProp:自适应学习率——对更新频繁的参数减小步长。RMSProp 用滑动平均避免 AdaGrad 学习率衰减过快。
- Adam:= Momentum + RMSProp,同时维护梯度的一阶矩(均值/动量)和二阶矩(方差/自适应步长),是深度学习的默认选择。
- AdamW:大模型训练的实际标准,见下。
Adam vs AdamW(高频)
区别在权重衰减(weight decay,即 L2 正则)的实现方式:
- Adam:把权重衰减加进梯度里,再经过自适应缩放——导致衰减强度被二阶矩「污染」,对不同参数不一致。
- AdamW:把权重衰减解耦(decoupled),直接作用在权重上,与梯度自适应分离:
$$\theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}t}+\epsilon} + \lambda \theta \right)$$
AdamW 让权重衰减回归「正确」的正则化效果,泛化更好,是几乎所有大模型预训练的默认优化器。
注意:Adam 要为每个参数存一阶矩和二阶矩,优化器状态约为参数量的 2 倍(且常用 FP32),这是训练显存的大头(详见 分布式训练)。
学习率调度(LR Schedule)
学习率不是固定的,典型策略是 Warmup + 衰减:
学习率
│ ╱‾‾‾‾‾╲___
│ ╱ ╲___
│ ╱ (cosine 余弦衰减) ╲___
│ ╱ ╲
└─┴────────────────────────────────▶ 训练步数
warmup 平滑下降- Warmup(预热):训练初期从极小学习率线性升到目标值。为什么需要? 初始参数随机、梯度噪声大,一上来用大学习率容易把模型「带崩」(尤其 Adam 的二阶矩估计还不准)。warmup 让训练先稳住。
- 衰减(如 Cosine):后期逐渐减小学习率,让模型精细收敛到更好的最小值。
混合精度训练
用低精度算、高精度存,省显存提速(详见 分布式训练):
- FP16:动态范围小,梯度易下溢为 0,必须配 Loss Scaling(把 loss 放大 N 倍,梯度同步放大避免下溢,更新前再缩回)。
- BF16:指数位与 FP32 相同,动态范围大,几乎不溢出,无需 loss scaling,是大模型首选(需 A100/H100 等支持)。
- FP8:H100 起支持,进一步提速,DeepSeek-V3 已用于大规模训练,是前沿方向。
梯度处理技巧
- 梯度累积(Gradient Accumulation):显存装不下大 batch 时,累积多个小 batch 的梯度再统一更新,等效放大 batch size。
- 梯度裁剪(Gradient Clipping):当梯度范数超过阈值时按比例缩小,防止「梯度爆炸」导致参数更新过大、训练崩溃。大模型训练几乎必用。
- 梯度检查点(Checkpointing):用重计算换激活显存,见 分布式训练。
训练稳定性:Loss Spike 怎么办?
超大模型训练常遇到 loss 突然飙升(loss spike) 甚至发散,是工程上的真实痛点。常见成因与对策:
| 成因 | 对策 |
|---|---|
| 学习率过大 | 降低峰值学习率、延长 warmup |
| 数值溢出(FP16) | 改用 BF16、调 loss scaling |
| 某批数据异常 | 跳过该 batch、加数据清洗 |
| 梯度爆炸 | 梯度裁剪 |
| logits 过大 | 加 z-loss(正则化 softmax 的归一化项) |
| 初始化/架构不稳 | 合适的初始化、Pre-Norm、调整残差缩放 |
实践中常用「从最近的 checkpoint 回滚 + 跳过问题数据 + 微调超参」来度过 spike。
高频追问
Q:Adam 和 AdamW 的区别?为什么大模型用 AdamW? 区别在权重衰减的实现:Adam 把衰减混进梯度、被自适应缩放扭曲;AdamW 把衰减解耦、直接作用于权重,使正则化效果正确,泛化更好。所以大模型预训练默认用 AdamW。
Q:为什么需要学习率 warmup? 训练初期参数随机、梯度噪声大、Adam 的二阶矩估计不准,直接用大学习率易导致训练发散。warmup 用很小的学习率起步、逐步升高,让训练先进入稳定区。
Q:FP16 和 BF16 训练怎么选? 优先 BF16:它指数位和 FP32 一样宽,动态范围大、几乎不溢出,省去 loss scaling 调参;FP16 尾数精度略高但范围小、易下溢,必须配 loss scaling。有 A100/H100 就用 BF16。
Q:梯度累积和增大 batch size 等价吗? 在数学上近似等价(都等效更大 batch),但梯度累积是「时间换显存」——多算几次再更新,速度变慢;它不减少计算量,只是绕过显存限制。注意 BatchNorm 等依赖 batch 统计的层会有差异(Transformer 用 LayerNorm 故无此问题)。
Q:loss spike(损失飙升)怎么处理? 先定位成因:学习率过大→降 LR/延长 warmup;FP16 溢出→换 BF16;坏数据→跳过该 batch;梯度爆炸→梯度裁剪;logits 过大→加 z-loss。工程上常「回滚到最近 checkpoint + 跳过问题数据 + 调超参」继续。
Q:梯度裁剪为什么重要? 大模型偶发的超大梯度会让参数更新过猛、直接训崩。梯度裁剪在梯度范数超阈值时按比例缩小,限制单步更新幅度,是保证训练稳定的常规手段。