文章目录
1. 概述
- 不是每个观察值都是同等重要
- 想只记住相关的观察需要
(1)能关注的机制(更新门)
(2)能遗忘的机制(重置门)
2. GRU关键组件
2.1 门
- 重置门 (Reset-Gate)
R t = σ ( X t W x r + H t − 1 W h r + b r ) (1) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\tag1 Rt=σ(XtWxr+Ht−1Whr+br)(1) - 更新门(Update-Gate)
Z t = σ ( X t W x z + H t − 1 W h z + b z ) (2) Z_t = \sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\tag2 Zt=σ(XtWxz+Ht−1Whz+bz)(2)

2.2 候选隐藏状态 H ~ t \widetilde{H}_t H t
H
~
t
=
tanh
(
X
t
W
x
h
+
(
R
t
⊙
H
t
−
1
)
W
h
h
+
b
h
)
(3)
\widetilde{H}_t=\tanh(X_tW_{xh}+(R_t \odot H_{t-1})W_{hh}+b_h)\tag3
H
t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh)(3)
注:因为
0
≤
R
t
≤
1
0\leq R_t \leq 1
0≤Rt≤1;所以我们可以通过
R
t
R_t
Rt来表示
H
~
t
\widetilde{H}_t
H
t有数据来自于过去
H
t
−
1
H_{t-1}
Ht−1
R_t=0:表示完全丢弃过去的信息,相当于重置,故为重置门R_t=1:表示完全接受过去的信息,相当于RNN

2.3 隐状态 H t H_t Ht
H
t
=
Z
t
⊙
H
t
−
1
+
(
1
−
Z
t
)
⊙
H
~
t
(4)
H_t=Z_t \odot H_{t-1}+(1-Z_t)\odot \widetilde{H}_t \tag4
Ht=Zt⊙Ht−1+(1−Zt)⊙H
t(4)

2.4 小结
R
t
=
σ
(
X
t
W
x
r
+
H
t
−
1
W
h
r
+
b
r
)
R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)
Rt=σ(XtWxr+Ht−1Whr+br)
Z
t
=
σ
(
X
t
W
x
z
+
H
t
−
1
W
h
z
+
b
z
)
Z_t = \sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)
Zt=σ(XtWxz+Ht−1Whz+bz)
H
~
t
=
tanh
(
X
t
W
x
h
+
(
R
t
⊙
H
t
−
1
)
W
h
h
+
b
h
)
\widetilde{H}_t=\tanh(X_tW_{xh}+(R_t \odot H_{t-1})W_{hh}+b_h)
H
t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh)
H
t
=
Z
t
⊙
H
t
−
1
+
(
1
−
Z
t
)
⊙
H
~
t
H_t=Z_t \odot H_{t-1}+(1-Z_t)\odot \widetilde{H}_t
Ht=Zt⊙Ht−1+(1−Zt)⊙H
t
- 重置门有助于捕获序列中的短期依赖关系
- 更新门有助于捕获序列中的长期依赖关系
3. 代码
- 源码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: GRU_test
# @Create time: 2022/1/27 16:10
# 1. 导入相关数据库
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l
# 2. 定义批量大小和文本长度
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
# 3. 设置GRU网络所需的参数
def get_params(vocab_size, num_hiddens, device):
num_inputs = num_outputs = vocab_size
def normal(shape):
return torch.randn(size=shape, device=device) * 0.01
def three():
return (normal((num_inputs, num_hiddens)),
normal((num_hiddens, num_hiddens)),
torch.zeros(num_hiddens, device=device))
W_xz, W_hz, b_z = three() # 更新门参数
W_xr, W_hr, b_r = three() # 重置门参数
W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数
W_hq = normal((num_hiddens, num_outputs))
b_q = torch.zeros(num_outputs, device=device)
# 附加梯度
params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
for param in params:
param.requires_grad_(True)
return params
# 4. 初始化 GRU 的隐变量的状态
def init_gru_state(batch_size, num_hiddens, device):
return (torch.zeros((batch_size, num_hiddens), device=device),)
# 5. 自定义 GRU 网络的 forward_fn;如果用到pytoch自带的,则不需要
def gru(inputs, state, params):
# 参数展开
W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
# 初始化参数
H, = state
# 输出
outputs = []
for X in inputs:
# update_gate 更新门
Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
# reset_gate 重置门
R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
# 候选隐藏状态
H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
# 隐藏状态
H = Z * H + (1 - Z) * H_tilda
# 输出 Y
Y = H @ W_hq + b_q
outputs.append(Y)
return torch.cat(outputs, dim=0), (H,)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
plt.show()
- 结果
perplexity 1.0, 408368.9 tokens/sec on cuda:0
time traveller fired as we lat said ally can andither at seare y
traveller sith a slight accession ofcheerfulness really thi

4. torch.nn.GRU
class torch.nn.GRU(*args, **kwargs)
-
说明:
将多层门控循环单元(GRU) RNN应用于输入序列。对于输入序列中的每个元素,每个层计算以下函数

-
input_size:输入x中预期特征的数量 -
hidden_size:处于隐藏状态h的特性数 -
num_layers:循环层数 -
代码说明:
import torch
from torch import nn
# GRU参数来说 input_size=10,output_size=20,layers=2
# 所以输入的最后一个维度必须是10,即 input[-1]=10
# nn.GRU(input_size=10,hidden_size=20,num_layers=2)
# L = sequence_length;N=Batch_size;H_in=input_size;H_out=hidden_size
# input =(L=8,N=6,H_in=10)
# h0=(2,N=6,H_out=20)
# output = (L=8,N=6,H_out=20)
# h_n=(2,N=6,H_out=20)
rnn = nn.GRU(10, 20, 2)
input = torch.randn(8, 6, 10)
h0 = torch.randn(2, 6, 20)
output, hn = rnn(input, h0)
print(f"input.shape={input.shape}")
print(f"h0.shape={h0.shape}")
print(f"output.shape={output.shape}")
print(f"hn.shape={hn.shape}")
print(f"rnn={rnn}")
input.shape=torch.Size([8, 6, 10])
h0.shape=torch.Size([2, 6, 20])
output.shape=torch.Size([8, 6, 20])
hn.shape=torch.Size([2, 6, 20])
rnn=GRU(10, 20, num_layers=2)
7048

被折叠的 条评论
为什么被折叠?



