当前位置: 首页 > news >正文

边写代码边学习之LSTM

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():# Create a simple Dense layerlstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,return_state= True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biaseslstm_layer(input_data)kernel, recurrent_kernel, biases = lstm_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)print('kernal:',kernel, kernel.shape) #(2,3)print('biase: ',biases , biases.shape) # (3)kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],[1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],[1, 1, 0, 0, 2,1,0,1,2,2,0,0],[1, 0, 1, 2, 0,1,0,1,1,0,1,0]])biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])lstm_layer.set_weights([kernel, recurrent_kernel, biases])print(lstm_layer.get_weights())# test_data = np.array([#     [[1.0, 3], [1, 1], [2, 3]]# ])test_data = np.array([[[1,0.0]]])output, memory_state, carry_state  = lstm_layer(test_data)print(output)print(memory_state)print(carry_state)
if __name__ == '__main__':change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.54054320.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032][ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.129630920.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807][ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.189349320.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835-0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ][ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.428892250.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],[1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],[1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],[1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]

http://www.lryc.cn/news/112335.html

相关文章:

  • Elasticsearch8.8.0 SpringBoot实战操作各种案例(索引操作、聚合、复杂查询、嵌套等)
  • 《MySQL高级篇》十五、其他数据库日志
  • 【Linux】【预】配置虚拟机的桥接网卡+nfs
  • 【Android】Retrofit2和RxJava2新手快速上手
  • 1.4 Nacos注册中心
  • AOJ 2200 Mr. Rito Post Office 最短路径+动态规划+谨慎+思维
  • 红米电视 ADB 安装 app 报错 failed to authenticate xxx:5555
  • Linux 下设置开机自启动的方法
  • MySQL常见问题处理(三)
  • maven中常见问题
  • vue2中bus的使用
  • 实证研究在机器学习中的应用
  • IO进程线程day8(2023.8.6)
  • 【5G NR】逻辑信道、传输信道和物理信道的映射关系
  • tmux基础教程
  • 项目实战 — 消息队列(4){消息持久化}
  • AI编程工具Copilot与Codeium的实测对比
  • webpack基础知识六:说说webpack的热更新是如何做到的?原理是什么?
  • Linux从安装到实战 常用命令 Bash常用功能 用户和组管理
  • webpack基础知识三:说说webpack中常见的Loader?解决了什么问题?
  • 深度学习:Pytorch常见损失函数Loss简介
  • 【Android-java】Parcelable 是什么?
  • Spring整合MyBatis小实例(转账功能)
  • List集合的对象传输的两种方式
  • 海外媒体发稿:软文写作方法方式?一篇好的软文理应合理规划?
  • 【秋招】算法岗的八股文之机器学习
  • 为什么list.sort()比Stream().sorted()更快?
  • SQL账户SA登录失败,提示错误:18456
  • Linux 终端操作命令(1)
  • java与javaw运行jar程序