
Commit a8865bb

Update pretrain_gpt2.py (deepspeedai#59)
Changing torch.optim.Adam to torch.optim.AdamW
1 parent 28a32fb commit a8865bb

File tree

1 file changed: +3 −1 lines changed

Megatron-LM/pretrain_gpt2.py

Lines changed: 3 additions & 1 deletion
@@ -111,9 +111,11 @@ def get_optimizer(model, args):
                 param.model_parallel = False
 
     if args.cpu_optimizer:
+        #Apex FusedAdam uses decoupled weight decay so use the same here
         if args.cpu_torch_adam:
-            cpu_adam_optimizer = torch.optim.Adam
+            cpu_adam_optimizer = torch.optim.AdamW
         else:
+            #TODO add option for decoupled weight decay in DeepCPUAdam
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             cpu_adam_optimizer = DeepSpeedCPUAdam
         optimizer = cpu_adam_optimizer(param_groups,
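
For context, a minimal sketch (not part of the commit) of the difference this swap targets: torch.optim.Adam folds weight_decay into the gradient as L2 regularization, while torch.optim.AdamW applies decoupled weight decay, matching the behavior of Apex FusedAdam. The Linear module below is only a stand-in for the model's parameter groups.

import torch

model = torch.nn.Linear(10, 10)  # stand-in for the GPT-2 parameters

# torch.optim.Adam: weight_decay is added to the gradient (L2 regularization),
# so the decay term passes through the adaptive Adam scaling.
adam = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.01)

# torch.optim.AdamW: weight decay is decoupled, so parameters are shrunk
# directly each step, independent of the gradient-based update
# (Loshchilov & Hutter, "Decoupled Weight Decay Regularization").
adamw = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)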
