一学一做看视频网站有哪些内容,石家庄微信网站建设,opencms wordpress,农产品交易平台前言
我在上一篇文章中介绍了 RNN#xff0c;它是一个隐变量模型#xff0c;主要通过隐藏状态连接时间序列#xff0c;实现了序列信息的记忆与建模。然而#xff0c;RNN在实践中面临严重的“梯度消失”与“长期依赖建模困难”问题#xff1a; 难以捕捉相隔很远的时间步之…前言
我在上一篇文章中介绍了 RNN它是一个隐变量模型主要通过隐藏状态连接时间序列实现了序列信息的记忆与建模。然而RNN在实践中面临严重的“梯度消失”与“长期依赖建模困难”问题 难以捕捉相隔很远的时间步之间的关系隐状态在不断更新中容易遗忘早期信息。 为了解决这些问题LSTMLong Short-Term Memory 网络于 1997 年被 Hochreiter等人提出该模型是对RNN的一次重大改进。 一、LSTM相比RNN的核心改进
接下来我们通过对比RNN、LSTM来看一下具体的改进
模型特点优势缺点RNN单一隐藏转态时间步传递结构简答容易造成梯度消失/爆炸对长期依赖差LSTM多门控机制 单独的“记忆单元”解决长距离依赖问题保留长期信息结构复杂计算开销大
通过对比我们可以发现其实LSTM的核心思想是引入了一个专门的“记忆单元”在通过门控机制对信息进行有选择的保留、遗忘与更新。 二、LSTM的核心结构
LSTM的核心结构如下图所示 如图可以轻松的看出LSTM主要由门控机制和候选记忆单元组成对于每个时间步LSTM都会进行以下操作
1. 忘记门
忘记门主要的作用是控制保留多少之前的记忆 2. 输入门
输入门主要的作用是决定当前输入中哪些信息信息被写入记忆 3. 候选记忆单元 4. 输出门
输出门的作用是决定是是否使用隐状态 5. 真正记忆单元
记忆单元 用于长期存储信息解决RNN容易遗忘的问题 7. 隐藏转态 LSTM引入了专门的记忆单元 长期存储信息解决了传统RNN容易遗忘的问题。 三、手写LSTM
通过上面的介绍我们现在已经知道了LSTM的实现原理现在我们试着手写一个LSTM核心层
首先初始化需要训练的参数
import torch
import torch.nn as nn
import torch.nn.functional as Fdef params(input_size, output_size, hidden_size):W_xi, W_hi, b_i torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xf, W_hf, b_f torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xo, W_ho, b_o torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xc, W_hc, b_c torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_hq torch.randn(hidden_size, output_size) * 0.1b_q torch.zeros(output_size)params [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]for param in params:param.requires_grad Truereturn params接着我们需要初始化0时刻的隐藏转态
import torchdef init_state(batch_size, hidden_size):return (torch.zeros((batch_size, hidden_size)), torch.zeros((batch_size, hidden_size)))然后 就是LSTM的核心操作
import torch
import torch.nn as nn
def lstm(X, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] params(H, C) stateoutputs []for x in X:I torch.sigmoid(torch.mm(x, W_xi) torch.mm(H, W_hi) b_i)F torch.sigmoid(torch.mm(x, W_xf) torch.mm(H, W_hf) b_f)O torch.sigmoid(torch.mm(x, W_xo) torch.mm(H, W_ho) b_o)C_tilde torch.tanh(torch.mm(x, W_xc) torch.mm(H, W_hc) b_c)C F * C I * C_tildeH O * torch.tanh(C)Y torch.mm(H, W_hq) b_qoutputs.append(Y)return torch.cat(outputs, dim1), (H, C) 四、使用Pytroch实现简单的LSTM
在Pytroch中已经内置了lstm函数我们只需要调用就可以实现上述操作
import torch
import torch.nn as nnclass mylstm(nn.Module):def __init__(self, input_size, output_size, hidden_size):super(mylstm, self).__init__()self.lstm nn.LSTM(input_size, hidden_size, batch_firstTrue)self.fc nn.Linear(hidden_size, output_size)def forward(self, x, h0, c0):out, (hn, cn) self.lstm(x, h0, c0)out self.fc(out)return out, (hn, cn)# 示例
input_size 10
hidden_size 20
output_size 10
batch_size 1
seq_len 5
num_layer 1 # lstm堆叠层数h0 torch.zeros(num_layer, batch_size, hidden_size)
c0 torch.randn(num_layer, batch_size, hidden_size)
x torch.randn(batch_size, seq_len, hidden_size)model mylstm(input_sizeinput_size, hidden_sizehidden_size, output_sizeoutput_size)out, _ model(x, (h0, c0))
print(out.shape) 总结
在现实中LSTM的实际应用场景很多比如语言模型、文本生成、时间序列预测、情感分析等长序列任务重这是因为相比于RNN而言LSTM能够更高地捕捉长期依赖而且也更好的缓解了梯度消失问题但是由于LSTM引入了三个门控机制导致参数量比RNN要多训练慢。
总的来说LSTM是对传统RNN的一次革命性升级引入门控机制和记忆单元使模型能够选择性地记忆与遗忘从而有效地捕捉长距离依赖。尽管LSTM近年来Transformer所取代但LSTM依然是理解深度学习序列模型不可绕开的一环有时在其他任务上甚至优于Transformer。 如果小伙伴们觉得本文对各位有帮助欢迎点赞 | ⭐ 收藏 | 关注。我将持续在专栏《人工智能》中更新人工智能知识帮助各位小伙伴们打好扎实的理论与操作基础欢迎订阅本专栏向AI工程师进阶