AdamW目前是大语言模型训练的默认优化器,而大部分资料对Adam跟AdamW区别的介绍都不是很明确,在此梳理一下Adam与AdamW的计算流程,明确一下二者的区别。
TLDR:AdamW将优化过程中使用的针对网络权重的衰减项(或者叫正则项)从loss中单独拿了出来,不参与Adam中一二阶动量的计算。
下面是二者的详细对比:
Adam
首先是Adam,给定在迭代步数 \(t\) 时模型的参数 \(\theta_t\) 与梯度 \(g_t\) ,Adam的计算公式如下:
- 式(1)用于计算梯度的一阶指数滑动平均
- 式(2)用于计算梯度的二阶项的指数滑动平均
- 式(3)与(4)对计算得到的指数滑动平均值进行消偏
- 式(5)为Adam的更新公式,其可以拆成两部分理解:动量更新与自适应学习率。
AdamW
AdamW 相对与Adam的改动十分简单,其将权重衰减项从梯度的计算中拿出来直接加在了最后的权重更新步骤上(下图,式12)。其提出的动机在于:原先Adam的实现中如果采用了L2权重衰减,则相应的权重衰减项会被直接加在loss里,从而导致动量的一阶与二阶滑动平均均考虑了该权重衰减项(下图. 式6),而这影响了Adam的优化效果,而将权重衰减与梯度的计算进行解耦能够显著提升Adam的效果。目前,AdamW现在已经成为transformer训练中的默认优化器了。

从上述的计算步骤中可以看出,Adam和AdamW在反向传播时需要维护的变量为原始参数 \(\theta_t\) ,梯度 \(g_t\) ,动量 \(m_t\) 与二阶动量 \(v_t\) ,明面上涉及的参数数量是网络参数的4倍。
实际上,使用Adam或AdamW进行训练时的显存的需求并不能简单的记为网络参数的倍数。训练过程中的显存分析是一件很复杂的事情,与训练过程超参的选取(比如batchsize,序列长度)、网络架构甚至是优化器的实现方式等都有关系,关于显存占用的分析,最近huggingface官方发了一篇博文来解释这个事情,感兴趣的读者可以阅读一下:Visualize and understand GPU memory in PyTorch