当前位置: 首页 > news >正文

rnn 和lstm源码学习笔记

目录

rnn学习笔记

lstm学习笔记


rnn学习笔记


import torchdef rnn(inputs, state, params):# inputs的形状: (时间步数量, 批次大小, 词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH = stateoutputs = []# 遍历每个时间步for X in inputs:# 计算隐藏状态 HH = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)# 计算输出 YY = torch.mm(H, W_hq) + b_qoutputs.append(Y)# 返回输出和新的隐藏状态return torch.cat(outputs, dim=0), (H,)# 参数示例初始化(根据实际情况调整)
input_size = 10  # 词表大小
hidden_size = 20  # 隐藏层大小
output_size = 5  # 输出大小# 初始化参数
W_xh = torch.randn(input_size, hidden_size)
W_hh = torch.randn(hidden_size, hidden_size)
b_h = torch.randn(hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.randn(output_size)params = (W_xh, W_hh, b_h, W_hq, b_q)
state = (torch.zeros(4,hidden_size))# 输入示例
time_steps = 3
batch_size = 4
inputs = torch.randn(time_steps, batch_size, input_size)# 调用RNN函数
outputs, new_state = rnn(inputs, state, params)
print(outputs)
print(new_state)

lstm学习笔记

import torch
import torch.nn as nndef lstm(inputs, state, params):# inputs的形状: (时间步数量, 批次大小, 词表大小)W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params(H, C) = stateoutputs = []# 遍历每个时间步for X in inputs:I = torch.sigmoid(torch.mm(X, W_xi) + torch.mm(H, W_hi) + b_i)F = torch.sigmoid(torch.mm(X, W_xf) + torch.mm(H, W_hf) + b_f)O = torch.sigmoid(torch.mm(X, W_xo) + torch.mm(H, W_ho) + b_o)C_tilda = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)# 参数示例初始化(根据实际情况调整)
input_size = 10  # 词表大小
hidden_size = 20  # 隐藏层大小
output_size = 5  # 输出大小# 初始化参数
W_xi = torch.randn(input_size, hidden_size)
W_hi = torch.randn(hidden_size, hidden_size)
b_i = torch.zeros(hidden_size)
W_xf = torch.randn(input_size, hidden_size)
W_hf = torch.randn(hidden_size, hidden_size)
b_f = torch.zeros(hidden_size)
W_xo = torch.randn(input_size, hidden_size)
W_ho = torch.randn(hidden_size, hidden_size)
b_o = torch.zeros(hidden_size)
W_xc = torch.randn(input_size, hidden_size)
W_hc = torch.randn(hidden_size, hidden_size)
b_c = torch.zeros(hidden_size)
W_hq = torch.randn(hidden_size, output_size)
b_q = torch.zeros(output_size)# 输入示例
time_steps = 3
batch_size = 4
inputs = torch.randn(time_steps, batch_size, input_size)params = (W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q)
state = (torch.zeros(batch_size, hidden_size), torch.zeros(batch_size, hidden_size))  # 初始隐藏状态和单元状态# 调用LSTM函数
outputs, new_state = lstm(inputs, state, params)
print(outputs)
print(new_state)

http://www.lryc.cn/news/358878.html

相关文章:

  • 解析Java中1000个常用类:CharSequence类,你学会了吗?
  • 微服务远程调用之拦截器实战
  • 德人合科技——天锐绿盾内网安全管理软件 | -文档透明加密模块
  • 超融合架构下,虚拟机高可用机制如何构建?
  • 工厂模式详情
  • 【Word】调整列表符号与后续文本的间距
  • 匠心独运,B 端系统 UI 演绎华章之美
  • Java电商平台-开放API接口签名验证(小程序/APP)
  • Tale全局函数对象base
  • 【启程Golang之旅】掌握Go语言数组基础概念与实际应用
  • COMSOL中液晶材料光学特性模拟
  • 虚拟现实环境下的远程教育和智能评估系统(五)
  • 【算法】模拟算法——Z字形变换(medium)
  • 上位机图像处理和嵌入式模块部署(f103 mcu获取唯一id)
  • 运筹学_3.运输问题(特殊的线性规划)
  • 科研项目书写作学习(持续更新中...)
  • 男士内裤哪个品牌好一点?2024热门男士内裤推荐
  • Llama模型家族之RLAIF 基于 AI 反馈的强化学习(六) RLAIF 代码实战
  • 计算机tcp/ip网络通信过程
  • 42.开发中对String.format()的使用之空位补齐
  • 通用代码生成器应用场景四,跨编程语言翻译
  • β-烟酰胺单核苷酸(NMN)功能不断得到验证 市场规模呈增长态势
  • 深入理解 Go 语言中的字符串不可变性与底层实现
  • 采购订单审批和取消例子
  • PHP:集成Xunsearch生成前端搜索骨架
  • ThreadLocal详解,与 HashMap 对比
  • flask流式接口
  • MatLab命令行常用命令记录
  • Linux —— MySQL操作(1)
  • TCP四次握手与http协议版本区别