1 parent 28a32fb commit a8865bb
Megatron-LM/pretrain_gpt2.py
@@ -111,9 +111,11 @@ def get_optimizer(model, args):
             param.model_parallel = False
 
     if args.cpu_optimizer:
+        # Apex FusedAdam uses decoupled weight decay, so use the same here
         if args.cpu_torch_adam:
-            cpu_adam_optimizer = torch.optim.Adam
+            cpu_adam_optimizer = torch.optim.AdamW
         else:
+            # TODO: add option for decoupled weight decay in DeepSpeedCPUAdam
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             cpu_adam_optimizer = DeepSpeedCPUAdam
         optimizer = cpu_adam_optimizer(param_groups,
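For context, a minimal sketch (not part of the commit) of why the swap matters: torch.optim.Adam folds weight_decay into the gradient as classic L2 regularization, so the decay term is rescaled by the adaptive moment estimates, while torch.optim.AdamW subtracts the decay from the weights directly, which matches the decoupled (AdamW-style) decay that Apex's FusedAdam applies by default. The layer shape, learning rate, and decay value below are arbitrary illustration choices; one identical step from identical weights already produces different results:

```python
import torch

# Two identical layers so both optimizers start from the same weights.
torch.manual_seed(0)
adam_layer = torch.nn.Linear(4, 1)
adamw_layer = torch.nn.Linear(4, 1)
adamw_layer.load_state_dict(adam_layer.state_dict())

# torch.optim.Adam adds weight_decay * param to the gradient (coupled L2),
# so the decay passes through the adaptive moment normalization.
adam = torch.optim.Adam(adam_layer.parameters(), lr=0.1, weight_decay=0.1)

# torch.optim.AdamW shrinks the weights directly (decoupled decay),
# independent of the gradient statistics.
adamw = torch.optim.AdamW(adamw_layer.parameters(), lr=0.1, weight_decay=0.1)

x = torch.randn(8, 4)
for layer, opt in ((adam_layer, adam), (adamw_layer, adamw)):
    opt.zero_grad()
    layer(x).sum().backward()
    opt.step()

# The weights diverge after a single step on the same data.
print((adam_layer.weight - adamw_layer.weight).abs().max())
```

Without this change, switching between the fused GPU optimizer and the CPU fallback would silently change the regularization behavior; using AdamW on the CPU path keeps the two consistent.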