1.知识点
- GRU只有两个门--重置门、更新门
- 初始状态只需要提供h0
- ht里“ * ”这个符号,代表着逐个元素相乘,不是矩阵相乘;其他是矩阵相乘

-
2.使用pytorch看参数个数
-
import torch import torch.nn as nn lstm_layer = nn.LSTM(3,5) #输入特征为3,隐含特征为5的特征量 gru_layer = nn.GRU(3,5) #同样的隐含大小 sum(p.numel() for p in lstm_layer.paramenters()) #计算lstm总参数量,调用paramenters()函数,在对其进行参数枚举(p代表参数),p.numel()计算每个参数p的所有元素进行统计----》200 sum(p.numel() for p in gru_layer.paramenters()) ----->150 ## GRU的参数量是LSTM的0.753.GRU网络代码
-
#准备工作 def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh): #定义函数,前向运算 initial_states提供初始状态 w权重--表示大的矩阵 b_ih偏置项 prev_h = initial_states #h t=0时刻的初始值 bs, T, i_size = i

文章详细解释了GRU(门控循环单元)的结构,强调它只有重置门和更新门,对比了GRU与LSTM的参数量,GRU参数量为LSTM的75%。通过PyTorch展示了GRU的前向传播过程,并提供了自定义GRU函数的实现,最后通过官方API与自定义函数的输出比较验证了代码的正确性。
2万+

被折叠的 条评论
为什么被折叠?



