当前位置: 首页 > news >正文

17- TensorFlow实现手写数字识别 (tensorflow系列) (项目十七)

项目要点

  • 模型创建: model = Sequential()
  • 添加卷积层: model.add(Dense(32, activation='relu', input_dim=100))  # 第一层需要 input_dim
  • 添加dropout: model.add(Dropout(0.2))
  • 添加第二次网络: model.add(Dense(512, activation='relu'))   # 除了first, 其他层不要输入shape
  • 添加输出层: model.add(Dense(num_classes, activation='softmax'))  # last 通常使用softmax
  • TensorFlow 中,使用 model.compile 方法来选择优化器和损失函数:
    • optimizer: 优化器: 主要有: tf.train.AdamOptimizer , tf.train.RMSPropOptimizer , or tf.train.GradientDescentOptimizer .

    • loss: 损失函数: 主要有:mean square error (mse, 回归), categorical_crossentropy (多分类) , and binary_crossentropy (二分类).

    • metrics: 算法的评估标准, 一般分类用accuracy.

  • model.fit(x_train, y_train, batch_size = 64, epochs = 20, validation_data = (x_test, y_test))    # 模型训练
  • score = model.evaluate(x_test, y_test, verbose=0)    两个返回值: [ 损失率 , 准确率 ]


1 实例演示Keras的使用 (手写数字识别)

1.1 导包

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import rmsprop_v2

1.2 导入数据

# 导入手写数字数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
'''(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)'''
import matplotlib.pyplot as plt
plt.imshow(x_train[0], cmap = 'gray')

 1.3 数据初步处理

# 对数据进行初步处理
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape, 'train samples')  # (60000, 784) train samples
print(x_test.shape, 'test samples')    # (10000, 784) test samples

1.4 数据初步处理

  • 独热编码
import tensorflow
# 将标记结果转化为独热编码
num_classes = 10
y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)
y_train

  1.5 创建模型

# 创建顺序模型
model = Sequential()
# 添加第一层网络, 512个神经元, 激活函数为relu
model.add(Dense(512, activation='relu', input_shape=(784,)))
# 添加Dropout
model.add(Dropout(0.2))
# 第二层网络
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
# 输出层
model.add(Dense(num_classes, activation='softmax'))
# 打印神经网络参数情况
model.summary()

 1.6 模型训练

# 编译
model.compile(loss='categorical_crossentropy',optimizer='rmsprop',metrics=['accuracy'])batch_size = 128
epochs = 20
# 训练并打印中间过程
history = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
# 计算预测数据的准确率
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])  # Test loss: 0.14742641150951385
print('Test accuracy:', score[1])   # Test accuracy: 0.9815000295639038

http://www.lryc.cn/news/22763.html

相关文章:

  • Polkadot 基础
  • spring源码编译
  • 防盗链是什么?带你了解什么是防盗链
  • Linux基础命令-fdisk管理磁盘分区表
  • (四)K8S 安装 Nginx Ingress Controller
  • 高频面试题
  • js 字节数组操作,TCP协议组装
  • JavaScript的引入并执行-包含动态引入与静态引入
  • 第四阶段01-酷鲨商城项目准备
  • Uncaught ReferenceError: jQuery is not defined
  • 面试阿里测开岗,被面试官针对,当场翻脸,把我的简历还给我,疑似被拉黑...
  • 2. 驱动开发--驱动开发环境搭建
  • 《数据库系统概论》学习笔记——第四章 数据库安全
  • 山洪径流过程模拟及洪水危险性评价
  • LeetCode HOT100 (23、32、33)
  • 电力监控仪表主要分类
  • 山野户外定位依赖GPS或者卫星电话就能完成么?
  • SAP 应收应付重组配置
  • 算法练习(八)计数质数(素数)
  • 用反射模拟IOC模拟getBean
  • 【Ap AutoSAR入门与实战开发02】-【Ap_s2s模块01】: s2s的背景
  • C语言数据结构(3)----无头单向非循环链表
  • Android 实现菜单拖拽排序
  • 通过window.open打开新的页面并修改样式添加内容
  • Java中 Synchronized 的用法
  • Rust语言的基本介绍
  • 新冠小阳人症状记录
  • SQL零基础入门学习(十四)
  • Excel工作表不能移动或复制?看看是不是这两个原因
  • 利用递归实现括号匹配