当前位置: 首页 > news >正文

python3+TensorFlow 2.x(三)手写数字识别

目录

代码实现

模型解析:

1、加载 MNIST 数据集:

2、数据预处理:

3、构建神经网络模型:

4、编译模型:

5、训练模型:

6、评估模型:

7、预测和可视化结果:

输出结果:

总结:


代码实现

TensorFlow 2.x 实现手写数字识别(MNIST 数据集)。MNIST 数据集包含了 28x28 像素的手写数字图像,任务是将这些图像分类为 10 个类别(0-9) 

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt# 1. 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()# 2. 数据预处理:归一化和改变形状
train_images = train_images / 255.0  # 将图像像素值归一化到 [0, 1]
test_images = test_images / 255.0# 调整形状,使得每张图片的维度是 [28, 28, 1],因为模型需要3D输入
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))# 3. 构建神经网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 10类分类问题
])# 4. 编译模型:选择优化器、损失函数和评价指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 因为标签是整数,所以使用 sparse_categorical_crossentropymetrics=['accuracy'])# 5. 训练模型
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))# 6. 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")# 7. 可视化训练过程中的损失和准确率变化
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 8. 使用模型进行预测
predictions = model.predict(test_images)# 显示一些预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted Label: {predictions[i].argmax()}, Actual Label: {test_labels[i]}")plt.show()

模型解析:

1、加载 MNIST 数据集:

使用 tf.keras.datasets.mnist.load_data() 函数来加载 MNIST 数据集。返回的数据包括训练集和测试集。训练集有 60,000 张图像,测试集有 10,000 张图像。

2、数据预处理:

将图像的像素值从 [0, 255] 归一化到 [0, 1],使每个像素的值在 0 到 1 之间,提升模型的训练效果。将每张图像的形状调整为 (28, 28, 1),即每个图像是 28x28 的灰度图像。

3、构建神经网络模型:

使用卷积神经网络(CNN)构建模型:Conv2D 层用于提取图像的特征,使用了 ReLU 激活函数。MaxPooling2D 层用于下采样,减少计算量。Flatten 层将卷积层的输出展平,进入全连接层。Dense 层用于输出分类结果,其中最后一层使用了 softmax 激活函数,将模型的输出转换为 10 类的概率分布。

4、编译模型:

使用 adam 优化器,sparse_categorical_crossentropy 作为损失函数(适用于类别标签是整数的情况),并使用 accuracy 作为评价指标。

5、训练模型:

使用 model.fit 训练模型,设置了 5 个 epoch,使用训练集进行训练,并验证模型在测试集上的表现。

6、评估模型:

使用 model.evaluate 在测试集上评估模型的准确性。并可视化训练过程中的损失和准确率变化:使用 matplotlib 绘制训练过程中的损失和准确率变化曲线,查看模型的学习进度。

7、预测和可视化结果

使用训练好的模型对测试集进行预测,展示一些预测结果,并与真实标签进行对比。

输出结果

训练和验证准确率:随着训练的进行,准确率应该逐渐提高。
测试准确率:训练完成后,模型在测试集上的准确率会显示出来,通常可以达到 98% 以上。
预测图像:展示一些手写数字图像,标注预测的标签和实际标签。

预测可视化展示

总结:

该模型使用了卷积层、池化层以及全连接层,在 MNIST 数据集上训练,最终达到了很好的分类效果。你可以调整模型的超参数(例如卷积层的数量、神经元的数量等)以提高性能。

http://www.lryc.cn/news/528747.html

相关文章:

  • 杨辉三角(蓝桥杯2021年H)
  • 【蓝桥杯嵌入式入门与进阶】2.与开发板之间破冰:初始开发板和原理图2
  • C++ queue
  • 【MySQL-7】事务
  • 03链表+栈+队列(D1_链表(D1_基础学习))
  • Git 出现 Please use your personal access token instead of the password 解决方法
  • AI大模型开发原理篇-1:语言模型雏形之N-Gram模型
  • STM32新建不同工程的方式
  • 【Rust自学】14.5. cargo工作空间(Workspace)
  • 全面了解 Web3 AIGC 和 AI Agent 的创新先锋 MelodAI
  • 10.3 LangChain实战指南:解锁大模型应用的10大核心场景与架构设计
  • Swing使用MVC模型架构
  • 设计新的 Kibana 仪表板布局以支持可折叠部分等
  • 修改maven的编码格式为utf-8
  • 解锁罗技键盘新技能:轻松锁定功能键(罗技K580)
  • HTB:Active[RE-WriteUP]
  • [C语言日寄] 源码、补码、反码介绍
  • 安卓逆向之脱壳-认识一下动态加载 双亲委派(一)
  • Nuxt:利用public-ip这个npm包来获取公网IP
  • babylon.js-3:了解STL网格模型
  • 基于SpringBoot的假期周边游平台的设计与实现(源码+SQL脚本+LW+部署讲解等)
  • 【MySQL】初始MySQL、库与表的操作
  • 将DeepSeek接入Word,打造AI办公助手
  • Coze,Dify,FastGPT,对比
  • Kafka 日志存储 — 磁盘存储
  • 996引擎 - NPC-添加NPC引擎自带形象
  • GL C++显示相机YUV视频数据使用帧缓冲FBO后期处理,实现滤镜功能。
  • 【hot100】刷题记录(7)-除自身数组以外的乘积
  • 解决.NET程序通过网盘传到Linux和macOS不能运行的问题
  • 练习(复习)