长短期记忆网络(LSTM)基本原理详解
一、LSTM核心思想
目标:解决传统RNN的梯度消失/爆炸问题,显式建模长期依赖关系
核心创新:引入细胞状态(Cell State)和门控机制,通过三个门结构精确控制信息流动


二、网络结构分解
1. 核心组件(四个关键部分)
| 组件 | 符号 | 功能描述 |
|---|---|---|
| 遗忘门 | ftf_tft | 决定从细胞状态中丢弃哪些信息 |
| 输入门 | iti_tit | 确定新信息存入细胞状态的比例 |
| 候选值 | C~t\tilde{C}_tC~t | 生成待存入细胞状态的新候选值 |
| 输出门 | oto_tot | 控制细胞状态到隐藏状态的输出比例 |
2. 数学公式推导
遗忘门(Forget Gate)
ft=σ(Wf⋅[ht−1,xt]+bf) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- σ\sigmaσ: Sigmoid函数(输出0-1间的遗忘比例)
输入门(Input Gate)
it=σ(Wi⋅[ht−1,xt]+bi) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
候选细胞状态
C~t=tanh(WC⋅[ht−1,xt]+bC) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
细胞状态更新
Ct=ft⊙Ct−1+it⊙C~t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
- ⊙\odot⊙: Hadamard积(逐元素相乘)
输出门(Output Gate)
ot=σ(Wo⋅[ht−1,xt]+bo) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
隐藏状态计算
ht=ot⊙tanh(Ct) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
三、PyTorch实现
1. LSTM单元实现
import torch
import torch.nn as nn
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
# 合并计算四个门的参数矩阵
self.W = nn.Linear(input_size + hidden_size, 4*hidden_size)
def forward(self, x, state):
# state = (h, c)
h_prev, c_prev = state
# 合并输入与隐藏状态
combined = torch.cat((x, h_prev), dim=1)
gates = self.W(combined)
# 分割四个门计算结果
f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
# 激活函数应用
f = torch.sigmoid(f) # 遗忘门
i = torch.sigmoid(i) # 输入门
o = torch.sigmoid(o) # 输出门
g = torch.tanh(g) # 候选值
# 更新细胞状态
c = f * c_prev + i * g
# 更新隐藏状态
h = o * torch.tanh(c)
return (h, c)
5万+

被折叠的 条评论
为什么被折叠?



