GRU门控循环单元读书笔记

1. 概述

  • 不是每个观察值都是同等重要
  • 想只记住相关的观察需要
    (1)能关注的机制(更新门)
    (2)能遗忘的机制(重置门)

2. GRU关键组件

2.1 门

  • 重置门 (Reset-Gate)
    R t = σ ( X t W x r + H t − 1 W h r + b r ) (1) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\tag1 Rt=σ(XtWxr+Ht1Whr+br)(1)
  • 更新门(Update-Gate)
    Z t = σ ( X t W x z + H t − 1 W h z + b z ) (2) Z_t = \sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\tag2 Zt=σ(XtWxz+Ht1Whz+bz)(2)
    在这里插入图片描述

2.2 候选隐藏状态 H ~ t \widetilde{H}_t H t

H ~ t = tanh ⁡ ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) (3) \widetilde{H}_t=\tanh(X_tW_{xh}+(R_t \odot H_{t-1})W_{hh}+b_h)\tag3 H t=tanh(XtWxh+(RtHt1)Whh+bh)(3)
注:因为 0 ≤ R t ≤ 1 0\leq R_t \leq 1 0Rt1;所以我们可以通过 R t R_t Rt来表示 H ~ t \widetilde{H}_t H t有数据来自于过去 H t − 1 H_{t-1} Ht1

  • R_t=0:表示完全丢弃过去的信息,相当于重置,故为重置门
  • R_t=1:表示完全接受过去的信息,相当于RNN
    在这里插入图片描述

2.3 隐状态 H t H_t Ht

H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t (4) H_t=Z_t \odot H_{t-1}+(1-Z_t)\odot \widetilde{H}_t \tag4 Ht=ZtHt1+(1Zt)H t(4)
在这里插入图片描述

2.4 小结

R t = σ ( X t W x r + H t − 1 W h r + b r ) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r) Rt=σ(XtWxr+Ht1Whr+br)
Z t = σ ( X t W x z + H t − 1 W h z + b z ) Z_t = \sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z) Zt=σ(XtWxz+Ht1Whz+bz)
H ~ t = tanh ⁡ ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) \widetilde{H}_t=\tanh(X_tW_{xh}+(R_t \odot H_{t-1})W_{hh}+b_h) H t=tanh(XtWxh+(RtHt1)Whh+bh)
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t H_t=Z_t \odot H_{t-1}+(1-Z_t)\odot \widetilde{H}_t Ht=ZtHt1+(1Zt)H t

  • 重置门有助于捕获序列中的短期依赖关系
  • 更新门有助于捕获序列中的长期依赖关系

3. 代码

  • 源码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: GRU_test
# @Create time: 2022/1/27 16:10

# 1. 导入相关数据库
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l

# 2. 定义批量大小和文本长度
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)


# 3. 设置GRU网络所需的参数
def get_params(vocab_size, num_hiddens, device):
	num_inputs = num_outputs = vocab_size

	def normal(shape):
		return torch.randn(size=shape, device=device) * 0.01

	def three():
		return (normal((num_inputs, num_hiddens)),
				normal((num_hiddens, num_hiddens)),
				torch.zeros(num_hiddens, device=device))

	W_xz, W_hz, b_z = three()  # 更新门参数
	W_xr, W_hr, b_r = three()  # 重置门参数
	W_xh, W_hh, b_h = three()  # 候选隐状态参数
	# 输出层参数
	W_hq = normal((num_hiddens, num_outputs))
	b_q = torch.zeros(num_outputs, device=device)
	# 附加梯度
	params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
	for param in params:
		param.requires_grad_(True)
	return params


# 4. 初始化 GRU 的隐变量的状态
def init_gru_state(batch_size, num_hiddens, device):
	return (torch.zeros((batch_size, num_hiddens), device=device),)


# 5. 自定义 GRU 网络的 forward_fn;如果用到pytoch自带的,则不需要
def gru(inputs, state, params):
	# 参数展开
	W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
	# 初始化参数
	H, = state
	# 输出
	outputs = []
	for X in inputs:
		# update_gate 更新门
		Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
		# reset_gate 重置门
		R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
		# 候选隐藏状态
		H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
		# 隐藏状态
		H = Z * H + (1 - Z) * H_tilda
		# 输出 Y
		Y = H @ W_hq + b_q
		outputs.append(Y)
	return torch.cat(outputs, dim=0), (H,)


vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

plt.show()
  • 结果
perplexity 1.0, 408368.9 tokens/sec on cuda:0
time traveller fired as we lat said ally can andither at seare y
traveller sith a slight accession ofcheerfulness really thi

在这里插入图片描述

4. torch.nn.GRU

class torch.nn.GRU(*args, **kwargs)
  • 说明:
    将多层门控循环单元(GRU) RNN应用于输入序列。对于输入序列中的每个元素,每个层计算以下函数
    在这里插入图片描述

  • input_size:输入x中预期特征的数量

  • hidden_size:处于隐藏状态h的特性数

  • num_layers:循环层数

  • 代码说明:

import torch
from torch import nn

# GRU参数来说 input_size=10,output_size=20,layers=2
# 所以输入的最后一个维度必须是10,即 input[-1]=10
# nn.GRU(input_size=10,hidden_size=20,num_layers=2)
# L = sequence_length;N=Batch_size;H_in=input_size;H_out=hidden_size
# input =(L=8,N=6,H_in=10)
# h0=(2,N=6,H_out=20)
# output = (L=8,N=6,H_out=20)
# h_n=(2,N=6,H_out=20)
rnn = nn.GRU(10, 20, 2)
input = torch.randn(8, 6, 10)
h0 = torch.randn(2, 6, 20)
output, hn = rnn(input, h0)
print(f"input.shape={input.shape}")
print(f"h0.shape={h0.shape}")
print(f"output.shape={output.shape}")
print(f"hn.shape={hn.shape}")
print(f"rnn={rnn}")
input.shape=torch.Size([8, 6, 10])
h0.shape=torch.Size([2, 6, 20])
output.shape=torch.Size([8, 6, 20])
hn.shape=torch.Size([2, 6, 20])
rnn=GRU(10, 20, num_layers=2)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值