当前位置: 首页 > news >正文

RNN(一)——循环神经网络的实现

文章目录

  • 一、循环神经网络RNN
    • 1.RNN是什么
    • 2.RNN的语言模型
    • 3.RNN的结构形式
  • 二、完整代码
  • 三、代码解读
    • 1.参数return_sequences
    • 2.调参过程

一、循环神经网络RNN

1.RNN是什么

循环神经网络RNN主要体现在上下文对理解的重要性,他比传统的神经网络(传统的神经网络结构:输入层-隐藏层-输出层)更细腻温情,前面所有的输入产生的结果都对后续输出产生影响,他关注隐层每个神经元在时间维度上的成长。体现在图上,就是表示隐层在不同时刻的状态。RNN在小数据集,低算力的情况下非常有效。

在这里插入图片描述
在这里插入图片描述

2.RNN的语言模型

在这里插入图片描述

3.RNN的结构形式

由于时序上的层级就够,使得RNN在输入输出关系上有很大的灵活性。以下是四种结构形式:

  1. 单入多出的形式:可实现看图说话等功能。
    在这里插入图片描述
  2. N to one:与上面一种刚好相反,输入很多句话,可以输出一张图片。

在这里插入图片描述

  1. N to N:输入输出等长序列。可生成文章、诗歌、代码等。

在这里插入图片描述

  1. N to M(Encoder-Decoder模型或Seq2Seq模型):将输入数据编码成上下文向量,然后输出预测的序列。常用语文本翻译、阅读理解、对话生成等很多领域广泛应用。

二、完整代码

# 一、前期准备
# 1.1 导入所需包和设置GPU
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 不显示等级2以下的提示信息
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, SimpleRNN
import matplotlib.pyplot as pltgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0], "GPU")
print(gpus)#1.2 导入数据
df = pd.read_csv('R1heart.csv')
print(df)df.isnull().sum()  #检查是否有空值#二、数据预处理
#2.1 数据集划分
x = df.iloc[:,:-1]
y = df.iloc[:,-1]x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
print(x_train.shape, y_train.shape)# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)#三、构建RNN模型model = Sequential()
model.add(SimpleRNN(128, input_shape= (13,1),return_sequences=True,activation='relu'))
model.add(SimpleRNN(64,return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()#四、编译模型
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy', optimizer=opt,metrics=['accuracy'])#五、训练模型
epochs = 100
history = model.fit(x_train, y_train,epochs=epochs,batch_size=128,validation_data=(x_test, y_test),verbose=1)
#六、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()scores = model.evaluate(x_test,y_test,verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

打印结果:
在这里插入图片描述

三、代码解读

1.参数return_sequences

当return_sequences=True时,无论输入序列的长度如何,输出都将是一个三维数组,其形状为[batch_size, sequence_length, output_dim]。这在处理序列数据时非常有用,特别是当你需要在多个时间步上使用层的输出时。

当return_sequences=False(默认值)时,只有序列中的最后一个时间步的输出会被返回,输出形状为[batch_size, output_dim]。

2.调参过程

尝试将RNN层分别增加到三层和四层,层数越多精确度越高,其中前n-1层都需要加参数return_sequences=True,意味着它的输出将保留整个序列的信息,可以被下一个RNN层使用,否则就会出现维度不匹配的情况,比如simple_rnn_2 层期望的输入数据维度是3(即,一个三维张量),但实际接收到的输入数据维度是2,就会出现报错。
也可尝试对全连接层的层数进行调整,也可对激活函数activation进行调整。但效果都不如调整RNN层数精确度高。

小记:
距离新疆之旅还有半个月,已经有点浮躁了,因为此次旅行有点不太一样,一家四口整整齐齐的分别从各自呆的城市“一起出发”,汇聚到同一趟车上,神奇吧!此行并不是突发奇想的说走就走的旅行,这个所谓的蓄谋已久持续了4年,多少还是有点期待的。那就在畅玩之前先整个“两周畅学卡”吧!

参考:
【循环神经网络】5分钟搞懂RNN,3D动画深入浅出

http://www.lryc.cn/news/407284.html

相关文章:

  • php 根据位置的经纬度计算距离
  • 17 Python常用内置函数——基本输入输出
  • 【Web】LitCTF 2024 题解(全)
  • 家政项目小程序的设计
  • electron TodoList网页应用打包成linux deb、AppImage应用
  • 【C语言】 使用fgets和fputs完成两个文件的拷贝
  • 使用PyTorch导出JIT模型:C++ API与libtorch实战
  • Python——异常捕获,传递及其抛出操作
  • 【Maven】 的继承机制
  • 微信小程序结合后端php发送模版消息
  • sqlalchemy报错sqlalchemy.orm.exc.DetachedInstanceError
  • 华为网络模拟器eNSP安装部署教程
  • 【React】详解样式控制:从基础到进阶应用的全面指南
  • 【ROS2】高级:安全-理解安全密钥库
  • C语言 ——— 数组指针的定义 数组指针的使用
  • opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习
  • 如何防止用户通过打印功能复制页面文字
  • Python3网络爬虫开发实战(3)网页数据的解析提取
  • 基于 HTML+ECharts 实现监控平台数据可视化大屏(含源码)
  • 立创梁山派--移植开源的SFUD和FATFS实现SPI-FLASH文件系统
  • MySQL之视图和索引实战
  • 快速参考:用C# Selenium实现浏览器窗口缩放的步骤
  • MyBatis 插件机制、分页插件如何实现的
  • CentOS6.0安装telnet-server启用telnet服务
  • H5+CSS+JS工作性价比计算器
  • Linux:基础命令学习
  • 遇到Websocket就不会测了?别慌,学会这个Jmeter插件轻松解决....
  • 高性能 Java 本地缓存 Caffeine 框架介绍及在 SpringBoot 中的使用
  • Http 和 Https 的区别(图文详解)
  • DP学习——外观模式