结论
对于 升级keras解决load_weights()中的未定义skip_mismatch关键字问题 中通过pip安装更高版本keras来解决TypeError: load_weights() got an unexpected keyword argument 'skip_mismatch’问题,首先需要确定model本身是通过keras API构建(模型继承于 keras.Model),还是通过tensorflow内置keras构建(模型继承于 tf.keras.Model)。
如果是前者,那么升级keras,可以解决。如果是后者,升级keras没有用,需要升级tensorflow,因为内置keras的版本与tensorflow的版本相关联。
如果不想使用tensorflow内置keras,可以考虑将tf.keras.Model改为keras.Model,同时修改相应的继承函数,例如call()转为__call__()。(这个方法没尝试成功,只是建议可以朝这个方向思考一下。)
项目场景:
- 环境配置(如图):
这里主要是说明keras和tensorflow-keras之间的关系,所以只放这两个即可。 - 使用tensorflow训练模型,并通过tf.keras.Model.save_weights(weight_filepath)保存模型权重。
问题描述
- 执行如下代码加载模型权重。
target_rnn = TargetRNN(args)
target_rnn.load_weights(weight_filepath
)
- 出现如下错误:

- 在网上看了一些解决方法,有人推荐使用参数by_name=True, skip_mismatch=True两个参数。(写文章的时候没找到原文。。。)代码实现变为如下形式。
target_rnn = TargetRNN(args)
target_rnn.load_weights(weight_filepath
by_name=True,
skip_mismatch=True)
- 进而产生新的错误:

- 找到 升级keras解决load_weights()中的未定义skip_mismatch关键字问题 ,按照步骤升级keras,没能解决我的问题。
原因分析:
-
既然显示load_weights() got an unexpected keyword argument 'skip_mismatch’的错误,说明参数没传对。直接查看函数定义,确实没有skip_mismatch的形参。
这里面有个细节,即keras的文件路径是在tensorflow下的,这说明keras是tensorflow内置的(当时没注意,但这一点是问题所在)。 -
连续尝试升级了keras 2.2.4、2.2.5、2.3.1,上图的load_weights函数中始终没有skip_mismatch的形参。说明 升级keras版本 并不能完全解决这个问题。
-
去keras官网搜索load_weights的API,发现有skip_mismatch形参。

-
因此想到查看我所安装的keras是否是所需的版本。在 查看Keras版本号 中发现,可以直接import keras,而不用通过tensorflow.keras使用keras,因此查看了两者的版本。结果如下
发现tensorflow使用的keras是老版本的2.1.6-tf,而我安装的keras版本是正确的。也就是说,我正常安装或升级了keras,但是因为我使用的是tensorflow内置的、老版本的keras,导致我一直没能更新load_weights的API。当我进入到keras的库中,找到load_weights()函数,印证了我的想法。
-
之所以我使用的是tensorflow内置keras的原因在于,我创建模型类的时候,继承的是tensorflow内置的keras,而不是直接继承keras。如下图

-
综上分析来看,
- 如果模型类继承自keras.Model,那么升级keras能解决问题;
- 如果继承自tensorflow.keras.Model,那么升级keras解决不了问题。要么升级tensorflow,要么模型类转变为继承keras.Model。
本文探讨了解决Keras在加载模型权重时遇到的TypeError问题,特别是关于未定义的'skip_mismatch'关键字。分析了问题根源在于使用的Keras版本与TensorFlow内置Keras版本不一致,提供了升级Keras或TensorFlow的解决方案。
3892

被折叠的 条评论
为什么被折叠?



