加载模型时发生错误RuntimeError: Error(s) in loading state_dict for Net:unexpected key(s) in state_dict: XXX

Traceback (most recent call last):
File "demo.py", line 380, in <module>
model.load_state_dict(torch.load('./0428.pth'))
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViT:
Unexpected key(s) in state_dict: "transformer.skipcat.3.weight", "transformer.skipcat.3.bias", "transformer.skipcat.4.weight", "transformer.skipcat.4.bias".
原因:
加载使用模型时和训练模型时的环境不一致.
解决方法:
将load_state_dict(state_dict) 改成 model.load_state_dict(state_dict, False)
model.load_state_dict(torch.load('models/params.pt'),strict=False)
问题解决~
在尝试加载预训练的PyTorch模型时遇到RuntimeError,错误指出state_dict中存在意外的键。问题源于训练模型时的环境与加载环境不一致。解决方法是使用`model.load_state_dict(torch.load('path'), strict=False)`,通过设置`strict=False`忽略不匹配的键,成功加载模型参数。
1万+

被折叠的 条评论
为什么被折叠?



