当前位置: 首页 > news >正文

tensorflow07——使用tf.keras搭建神经网络(Sequential顺序神经网络)——六步法——鸢尾花数据集分类

使用tf.keras搭建顺序神经网络
六步法——鸢尾花数据集分类

01 导入相关包
02 导入数据集,打乱顺序
03 建立Sequential模型
04 编译——确定优化器,损失函数,评测指标(用哪一种准确率)
05 训练模型——把各项参入填入模型
06 总结——打印网络结构


# 01
import tensorflow as tf
from sklearn import datasets
import numpy as np# 02
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
# 测试集可以在此处按照上述方法划分
# 本案例把测试集放到训练过程fit中,按照比例直接从训练集中划分(validation_split)# 乱序步骤
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)# 03
model = tf.keras.models.Sequential([# 定义全连接层tf.keras.layers.Dense(3,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])# 04
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 05
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2,validation_freq=20)# 06
model.summary()

输出结果

Train on 120 samples, validate on 30 samples
Epoch 1/500
120/120 [==============================] - 0s 3ms/sample - loss: 2.2022 - sparse_categorical_accuracy: 0.3833
Epoch 2/500
120/120 [==============================] - 0s 36us/sample - loss: 1.0013 - sparse_categorical_accuracy: 0.6083
Epoch 3/500
120/120 [==============================] - 0s 36us/sample - loss: 0.8497 - sparse_categorical_accuracy: 0.6333
。
。
此处省略500回合
。
。
。> Epoch 496/500 120/120 [==============================] - 0s
> 21us/sample - loss: 0.3384 - sparse_categorical_accuracy: 0.9583 Epoch
> 497/500 120/120 [==============================] - 0s 22us/sample -
> loss: 0.3442 - sparse_categorical_accuracy: 0.9750 Epoch 498/500
> 120/120 [==============================] - 0s 22us/sample - loss:
> 0.3394 - sparse_categorical_accuracy: 0.9583 Epoch 499/500 120/120 [==============================] - 0s 21us/sample - loss: 0.3394 -
> sparse_categorical_accuracy: 0.9333 Epoch 500/500 120/120
> [==============================] - 0s 168us/sample - loss: 0.4425 -
> sparse_categorical_accuracy: 0.8583 - val_loss: 0.3130 -
> val_sparse_categorical_accuracy: 0.9667 Model: "sequential"
> _________________________________________________________________ Layer (type)                 Output Shape              Param #   
> ================================================================= dense (Dense)                multiple                  15        
> ================================================================= Total params: 15 Trainable params: 15 Non-trainable params: 0
> ________________________________________________________________

由于sequential是顺序模型,不方便在中间加入其他步骤
可以采取类封装的形式,新建一个类,将整个神经网络模型封装装起来
里面设置两个函数方法_ _ init _ _和call
_ _ init _ _用于定义网络结构块
call用于实现前向传播

import tensorflow as tf
from tensorflow.keras.layers import Dense #新增
from tensorflow.keras import Model		  #新增
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)#类名 IrisModel
class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()#定义——网络结构块self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):#调用——网络结构快,实现前向传播y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
http://www.lryc.cn/news/28070.html

相关文章:

  • 关于Java连接Hive,Spark等服务的Kerberos工具类封装
  • 大数据框架之Hadoop:MapReduce(五)Yarn资源调度器
  • uniapp实现地图点聚合功能
  • 经典分类模型回顾2—GoogleNet实现图像分类(matlab版)
  • Java经典面试题——谈谈 final、finally、finalize 有什么不同?
  • C#的Version类型值与SQL Server中二进制binary类型转换
  • 软测入门(五)接口测试Postman
  • UWB通道选择、信号阻挡和反射对UWB定位范围和定位精度的影响
  • linux基本功之列之wget命令实战
  • 学习ROS时针对gazebo相关的问题(重装与卸载是永远的神)
  • 几个C语言容易忽略的问题
  • CentOS 7.9安装Zabbix 4.4《保姆级教程》
  • 路由器与交换机的区别(基础知识)
  • Python基础学习9——函数
  • 项目中的MD5、盐值加密
  • 电商项目后端框架SpringBoot、MybatisPlus
  • 2023年03月IDE流行度最新排名
  • 华为校招机试 - 数组取最小值(Java JS Python)
  • 20 客户端服务订阅的事件机制剖析
  • ThreadPoolExecutor中的addWorker方法
  • 9 有线网络的封装
  • Linux----网络基础(2)--应用层的序列化与反序列化--守护进程--0226
  • uipath实现滑动验证码登录
  • openai-chatGPT的API调用异常处理
  • css实现音乐播放器页面 · 笔记
  • buu [NPUCTF2020]这是什么觅 1
  • Restful API 设计规范
  • sigwaittest测试超标的调试过程
  • Python进阶-----面对对象4.0(面对对象三大特征之--继承)
  • 九龙证券|利好政策密集发布,机构扎堆看好的高增长公司曝光