原因:网络为多头输出,input[0]的内容为list(第一个输出、第二个输出、第三个输出...)。i
修改./lib/python3.7/site-packages/torchsummary/torchsummary.py中的代码
summary[m_key]["input_shape"].append(x.shape)
修改为如下所示:
if isinstance(input[0],list):
summary[m_key]["input_shape"] = []
for x in input[0]:
summary[m_key]["input_shape"].append(x.shape)
else:
summary[m_key]["input_shape"] = list(input[0].size())
文章讲述了在处理网络的多头输出时,对torchsummary库中`torchsummary.py`文件的`input_shape`记录进行调整的方法,当input[0]是list时,使用循环添加每个tensor的shape,否则直接存储单个tensor的大小。
4万+

被折叠的 条评论
为什么被折叠?



