Module类内置了很多函数,其中本文主要介绍常用的属性设置函数,包括向module添加参数的register_parameter(),register_buffer()。官方文档如下:Module — PyTorch 1.7.0 documentation
这两种方法均可以往模型中额外添加参数。不同的是register_parameter() 添加的参数在模型训练时可以正常更新,但register_buffer()则不进行参数更新。在保存模型时,两者的参数都会保存。
示例1:分别使用以下四种形式来定义参数,查看网络更新的参数
- 常用的 nn.Sequential() ,即nn.Module子类的形式
- 使用 register_buffer() 定义参数
- 使用register_parameter() 定义参数
- 使用 Net 类的自定义属性定义参数
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
# 使用 nn.Sequential定义参数
self.features

本文介绍了PyTorch中Module类的register_parameter()和register_buffer()函数,用于向模型添加参数。register_parameter()添加的参数在训练时可更新且保存在模型状态字典中,而register_buffer()添加的参数不参与训练但也会被保存。普通类属性定义的参数既不参与训练也不被保存。
7810

被折叠的 条评论
为什么被折叠?



