在解决TypeError: load_weights() got an unexpected keyword argument ‘skip_mismatch‘中的一点小提示

本文探讨了解决Keras在加载模型权重时遇到的TypeError问题,特别是关于未定义的'skip_mismatch'关键字。分析了问题根源在于使用的Keras版本与TensorFlow内置Keras版本不一致,提供了升级Keras或TensorFlow的解决方案。

结论

对于 升级keras解决load_weights()中的未定义skip_mismatch关键字问题 中通过pip安装更高版本keras来解决TypeError: load_weights() got an unexpected keyword argument 'skip_mismatch’问题,首先需要确定model本身是通过keras API构建(模型继承于 keras.Model),还是通过tensorflow内置keras构建(模型继承于 tf.keras.Model)。
如果是前者,那么升级keras,可以解决。如果是后者,升级keras没有用,需要升级tensorflow,因为内置keras的版本与tensorflow的版本相关联。

如果不想使用tensorflow内置keras,可以考虑将tf.keras.Model改为keras.Model,同时修改相应的继承函数,例如call()转为__call__()。(这个方法没尝试成功,只是建议可以朝这个方向思考一下。)


项目场景:

  • 环境配置(如图):环境配置这里主要是说明keras和tensorflow-keras之间的关系,所以只放这两个即可。
  • 使用tensorflow训练模型,并通过tf.keras.Model.save_weights(weight_filepath)保存模型权重。

问题描述

  1. 执行如下代码加载模型权重。
target_rnn = TargetRNN(args)
target_rnn.load_weights(weight_filepath
                       )
  1. 出现如下错误:valueError,尝试加载一个权重文件.....
  2. 在网上看了一些解决方法,有人推荐使用参数by_name=True, skip_mismatch=True两个参数。(写文章的时候没找到原文。。。)代码实现变为如下形式。
target_rnn = TargetRNN(args)
target_rnn.load_weights(weight_filepath
                        by_name=True,  
                        skip_mismatch=True)
  1. 进而产生新的错误:_weights()没有skip_mismatch属性
  2. 找到 升级keras解决load_weights()中的未定义skip_mismatch关键字问题 ,按照步骤升级keras,没能解决我的问题。

原因分析:

  1. 既然显示load_weights() got an unexpected keyword argument 'skip_mismatch’的错误,说明参数没传对。直接查看函数定义,确实没有skip_mismatch的形参。tensorflow中函数定义这里面有个细节,即keras的文件路径是在tensorflow下的,这说明keras是tensorflow内置的(当时没注意,但这一点是问题所在)。

  2. 连续尝试升级了keras 2.2.4、2.2.5、2.3.1,上图的load_weights函数中始终没有skip_mismatch的形参。说明 升级keras版本 并不能完全解决这个问题。

  3. 去keras官网搜索load_weights的API,发现有skip_mismatch形参。
    请添加图片描述

  4. 因此想到查看我所安装的keras是否是所需的版本。在 查看Keras版本号 中发现,可以直接import keras,而不用通过tensorflow.keras使用keras,因此查看了两者的版本。结果如下查看版本号发现tensorflow使用的keras是老版本的2.1.6-tf,而我安装的keras版本是正确的。也就是说,我正常安装或升级了keras,但是因为我使用的是tensorflow内置的、老版本的keras,导致我一直没能更新load_weights的API。当我进入到keras的库中,找到load_weights()函数,印证了我的想法。keras的load_weights

  5. 之所以我使用的是tensorflow内置keras的原因在于,我创建模型类的时候,继承的是tensorflow内置的keras,而不是直接继承keras。如下图请添加图片描述

  6. 综上分析来看,

    • 如果模型类继承自keras.Model,那么升级keras能解决问题;
    • 如果继承自tensorflow.keras.Model,那么升级keras解决不了问题。要么升级tensorflow,要么模型类转变为继承keras.Model。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值