当前位置: 首页 > news >正文

LSTM每个变量的shape分析

首先要理解一点,LSTM中迭代不光是要遍历每个批量,在处理每个批量时,还需要以时间步来推进,这样才能体现时间上的关系,才能得到state。

如下是权重的初始化以及RNN的model定义:

 

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params
class RNNModel(nn.Module):"""The RNN model."""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# If the RNN is bidirectional (to be introduced later),# `num_directions` should be 2, else it should be 1.if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# The fully connected layer will first change the shape of `Y` to# (`num_steps` * `batch_size`, `num_hiddens`). Its output shape is# (`num_steps` * `batch_size`, `vocab_size`).output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# `nn.GRU` takes a tensor as hidden statereturn torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device)else:# `nn.LSTM` takes a tuple of hidden statesreturn (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device))

可以看到forward函数第一行代码,会将输入进行转置,然后进行独热编码,使得输入input的shape变化:[batch_size,seq_len,1]->[seq_len,batch_size,vocab_size]

 接下来的输入会进入lstm层,定义如下:

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)#这里可以看出C与H的size是相同的Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

 这里有个迭代:for X in inputs,就是在迭代时间步,每次迭代就是一个时间步,因此此时的X的shape是[batch_size,vocab_size]

而W_xi的shape是[vocab_size,num_hiddens],故得到的I的shape应该是[batch_sizenum_hiddens]的,其他变量的分析也同理

http://www.lryc.cn/news/575821.html

相关文章:

  • 从输入到路径:AI赋能的地图语义解析与可视化探索之旅
  • 通过ETL从MySQL同步到GaussDB
  • 喜讯 | Mediatom斩获2025第十三届TopDigital创新营销奖「年度程序化广告平台」殊荣
  • LINUX625 DNS反向解析
  • 基于 Spring Boot + Vue 3的现代化社区团购系统
  • 科技如何影响我们的生活?
  • 工业电子 | 什么是SerDes,为何工业和汽车应用需要它?
  • HarmonyOS NEXT仓颉开发语言实战案例:简约音乐播放页
  • 金蝶云星空客户端自定义控件插件-WPF实现自定义控件
  • 使用Docker部署mysql8
  • 社会工程--如何使用对方的语言
  • JDBC入门:Java连接数据库全指南
  • AI辅助编写前端VUE应用流程
  • 树状dp(dfs)(一道挺基础的)
  • Spring Boot 项目问题:while constructing a mapping found duplicate key api
  • 微信小程序封装loading 修改
  • 常见网络安全威胁和防御措施
  • 智能实验室革命:Deepoc大模型驱动全自动化科研新生态
  • HTML简介,初步了解HTML
  • SQl中多使用EXISTS导致多查出了一条不符合条件的数据
  • 教程 | 一键批量下载 Dify「Markdown 转 Docx」生成的 Word 文件(附源码)
  • 【Linux】基础开发工具(2)
  • 操作系统面试知识点(1):操作系统基础
  • CyberGlove触觉反馈手套遥操作机器人灵巧手解决方案
  • Kotlin环境搭建与基础语法入门
  • 大厂测开实习和小厂开发实习怎么选
  • 华为云鸿蒙应用入门级开发者认证 实验(HCCDA-HarmonyOS Cloud Apps)
  • linux网络编程socket套接字
  • mysql无法启动的数据库迁移
  • WebSocket 与 HTTP 的区别及 Spring Boot 实战应用