Description
I cannot restore a training model from a checkpoint. The error log:
  File "/data/zmining/jupyter-notebook/antnh/embeddings/tml/common/checkpointing/snapshot.py", line 67, in restore
    snapshot.restore(self.state) #check
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torchsnapshot/snapshot.py", line 406, in restore
    self._load_stateful(
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torchsnapshot/snapshot.py", line 671, in _load_stateful
    stateful.load_state_dict(state_dict)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2027, in load_state_dict
    load(self, state_dict)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2015, in load
    load(child, child_state_dict, child_prefix)
  [Previous line repeated 2 more times]
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2009, in load
    module._load_from_state_dict(
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1909, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torch/nn/modules/module.py", line 69, in __call__
    return self.hook(module, *args, **kwargs)
  File "/home/zdeploy/anaconda3/envs/quanhm_torchrec/lib/python3.8/site-packages/torchrec/distributed/embeddingbag.py", line 439, in _pre_load_state_dict_hook
    local_shards = state_dict[key].local_shards()
KeyError: 'model._dmp_wrapped_module.module.large_embeddings.ebc.embedding_bags.user.weight'
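The traceback shows the TorchRec pre-load hook indexing `state_dict[key]` with a name that is not in the state dict being restored. One way to narrow this down might be to compare the keys the model expects with the keys stored in the checkpoint. The snippet below is only a minimal sketch of that comparison: the helper `diff_state_dict_keys` and both key sets are hypothetical, and the assumption that the checkpoint was saved without the `_dmp_wrapped_module.module` prefix is a guess for illustration, not something confirmed by the log.

```python
from typing import Iterable, Set, Tuple


def diff_state_dict_keys(
    expected_keys: Iterable[str], saved_keys: Iterable[str]
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) keys relative to what the model expects."""
    expected, saved = set(expected_keys), set(saved_keys)
    missing = expected - saved      # the model wants these, the checkpoint lacks them
    unexpected = saved - expected   # the checkpoint has these, the model does not
    return missing, unexpected


# Hypothetical key sets for illustration only. In practice they would come from
# model.state_dict().keys() and from the keys actually stored in the checkpoint.
expected = {
    "model._dmp_wrapped_module.module.large_embeddings.ebc.embedding_bags.user.weight",
    "model.dense.weight",
}
saved = {
    # Assumed naming (not taken from the log): saved without the DMP wrapper prefix.
    "model.large_embeddings.ebc.embedding_bags.user.weight",
    "model.dense.weight",
}

missing, unexpected = diff_state_dict_keys(expected, saved)
print("missing from checkpoint:", sorted(missing))
print("unexpected in checkpoint:", sorted(unexpected))
```

If the missing and unexpected names differ only by a wrapper prefix, that would suggest the model was wrapped differently at save time and at restore time.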
Thank you very much for helping me.