当前位置: 首页 > news >正文

【TensorFlow】T1:实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

1、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0],"GPU")print(gpus)
# 输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

固定模板,直接调用即可

2、导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images -- 训练集图片
# train_labels -- 训练集标签
# test_images  -- 测试集图片
# test_labels  -- 测试机标签

此处所需数据直接联网下载,故无需数据集

3、归一化

将图像数据从0-255缩放到0-1,主要是为了加快训练速度和提高模型的性能,使得模型训练过程更加稳定高效:

  1. 加速收敛:帮助优化算法更快找到最优解。
  2. 提升性能:确保不同规模的数据对模型的影响均衡,提高预测准确性。
  3. 简化调参:使超参数的选择更加简单有效。
  4. 数值稳定:避免因数据尺度问题导致的计算异常,如梯度爆炸。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

4、可视化

# figure函数--设置画布大小
# figsize--指定了画布的宽度(20英寸)和高度(10英寸)
plt.figure(figsize=(20, 10))# 遍历前20张图片
for i in range(20):# subplot函数--创建子图# 2--2行,10--10列,i+1--第i+1个位置plt.subplot(2, 10, i+1)# xticks函数--隐藏x轴plt.xticks([])# yticks函数--隐藏y轴plt.yticks([])# grid函数--关闭网格线plt.grid(False)# imshow函数--绘制图像# train_images[i]--需要显示的图片,cmap--设置颜色映射(binary)plt.imshow(train_images[i], cmap=plt.cm.binary)# xlabel函数--给图像添加标签plt.xlabel(train_labels[i])plt.show()

输出:
在这里插入图片描述

5、调整格式

# reshape函数--只改变数据形状,不改变数据内容
# 6000--6000个样本,28--28*28像素,1--颜色通道数(灰度)
train_images = train_images.reshape((60000, 28, 28, 1))
# 1000--1000个样本,28--28*28像素,1--颜色通道数(灰度)
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)

6、构建CNN网络模型

关键组件:

  1. 卷积层(Conv2D):主要用于提取图像的特征。它通过滑动窗口(即滤波器或核)在输入图像上移动,计算每个位置上的点积来生成特征图。每个滤波器可以识别特定类型的特征,如边缘或纹理。
  2. 激活函数(Activation function):如ReLU(Rectified Linear Unit),用于引入非线性因素,使模型能够学习更复杂的模式。
  3. 最大池化层(MaxPooling2D):用于减小特征图的空间维度(宽度和高度),同时保留最重要的信息。这有助于减少计算量并控制过拟合。
  4. 展平层(Flatten)将多维的卷积层输出转换为一维向量,以便将其输入到全连接层中。
  5. 全连接层(Dense):在这里,每个神经元与前一层的所有神经元相连,用于最终的分类任务。最后一层通常不使用激活函数(或者使用softmax函数),以直接输出每个类别的得分。
# Sequential函数--创建了一个Sequential模型,允许按顺序堆叠各层
model = models.Sequential([# Conv2D--第一个二维卷积层# 32--32个滤波器,(3, 3)--每个滤波器的大小为(3, 3),activation0--激活函数(relu),input_shape--输入数据的大小(28*28像素、单通道)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),# MaxPooling2D--第一个二维最大池化层# 2--池化窗口的大小为2*2,用于下采样时减少特征维度layers.MaxPooling2D(2, 2),# Conv2D--第二个二维卷积层layers.Conv2D(64, (3, 3), activation='relu'),# MaxPooling2D--第二个二维最大池化层layers.MaxPooling2D(2, 2),# Flatten--将多维的卷积层输出展平为一维向量,方便输入全连接层layers.Flatten(),# Dense--全连接层# 64--64个节点,activation0--激活函数(relu)layers.Dense(64, activation='relu'),# Dense--输出层# 10--10个节点(此处对应0-9),默认线性激活函数layers.Dense(10)
])# 打印模型结构和参数信息,包括每一层的输出形状和参数数量
print(model.summary())

输出:
在这里插入图片描述

7、编译模型

# compile函数--配置模型的编译设置
model.compile(# optimizer--优化器(adam)# Adam:一种基于一阶梯度的优化算法,能自适应地调整不同参数的学习率,适用于大规模数据和高维空间问题optimizer='adam',# loss--损失函数(Sparse Categorical Crossentropy)# Sparse Categorical Crossentropy(稀疏分类交叉熵损失):适用于多分类问题,特别是标签是整数形式时# from_logits=True:表示网络输出未经过softmax激活函数处理。这种情况下,损失函数会自动应用softmax来计算最终的分类概率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# metrics--性能指标# accuracy(准确率):表示预测正确的样本数占总样本数的比例metrics=['accuracy']
)

8、训练模型

  1. 主要参数:
  • x训练数据的输入(特征)。它可以是一个 Numpy 数组或者一个 TensorFlow 数据集等。对于图像数据,这通常是一个形状为 (样本数量, 高度, 宽度, 颜色通道) 的四维数组。
  • y目标数据(标签)。与 x 对应,表示每个输入样本的真实类别或值。对于分类任务,这通常是一个一维数组,长度等于训练样本的数量;对于多分类问题,可能是一个二维数组,采用 one-hot 编码。
  • batch_size每次梯度更新时使用的样本数。默认值是 32。较大的批量大小可以使计算更高效,但需要更多的内存。
  • epochs整个训练集迭代的次数。每次迭代称为一个 epoch。增加 epoch 数量可以让模型有更多机会学习数据中的模式,但也增加了过拟合的风险。
  • verbose:日志显示模式。0 表示不输出日志到屏幕上,1 表示输出进度条记录,2 表示每个epoch输出一行记录。
  • validation_data:在每个 epoch 结束后用于评估模型性能的数据集。这是一个元组 (x_val, y_val) 或者一个包含输入和目标数据的列表。这对于监控模型在未见过的数据上的表现非常重要。
  • validation_split:从训练数据中划分出一定比例的数据作为验证集,取值范围是 (0, 1)。注意,这个参数只会在 x 是 Numpy 数组时有效。
  • shuffle:在每个 epoch 开始前是否打乱训练数据。这对于确保模型不会因为数据顺序而产生偏差非常重要,默认值为 True。
  • callbacks:一个列表,其中包含各种回调函数对象,在训练过程中这些回调函数会在特定时间点被调用(如每轮结束、训练开始或结束等),可用于实现提前停止、动态调整学习率等功能。
  1. 返回值:
  • History对象:该对象包含了一个 history 属性,这是一个字典,包含了训练过程中损失值评估指标的变化情况。可以访问 history.history[‘loss’] 来获取每个 epoch 的训练损失值,或 history.history[‘val_accuracy’] 获取每个 epoch 的验证准确率。
    “”"
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels),
)

9、模型预测

plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
# 输出:[-0.0146451 0.11037247 -0.01110678 0.03087252 -0.02923543 -0.10968889 -0.00841374 0.04551534  -0.02969249 -0.00869128]

输出:
在这里插入图片描述

10、完整代码

# 1.设置GPU
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]tf.config.experimental.set_memory_growth(gpu0, True)tf.config.set_visible_devices([gpu0], "GPU")print(gpus)# 2.导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()# 3.归一化
train_images, test_images = train_images / 255.0, test_images / 255.0print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 4.可视化
plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(2,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()# 5.调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)# 6.构建CNN网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10)
])
print(model.summary())# 7.编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 8.训练模型
history = model.fit(train_images,train_labels,epochs=10,validation_data=(test_images, test_labels))# 9.模型预测
plt.imshow(test_images[1])
plt.show()pre = model.predict(test_images)
print(pre[1])
http://www.lryc.cn/news/532021.html

相关文章:

  • Rapidjson 实战
  • 【React】受控组件和非受控组件
  • Ollama+deepseek+Docker+Open WebUI实现与AI聊天
  • DEEPSEKK GPT等AI体的出现如何重构工厂数字化架构:从设备控制到ERP MES系统的全面优化
  • 阿莱(arri)mxf文件变0字节的恢复方法
  • 初识 Node.js
  • debug-vscode调试方法
  • Cypher进阶(函数、索引)
  • XML Schema 数值数据类型
  • Window获取界面空闲时间
  • Java进阶(vue基础)
  • Mac电脑上好用的压缩软件
  • Ubuntn24.04安装
  • 基于ansible部署elk集群
  • 解锁.NET Fiddle:在线编程的神奇之旅
  • 记录pve中使用libvirt创建虚拟机
  • 【HTML性能优化】提升网站加载速度:GZIP、懒加载与资源合并
  • 三维空间全局光照 | 及各种扫盲
  • 数据库开发常识(10.6)——SQL性能判断标准及索引误区(1)
  • 网络爬虫js逆向之某音乐平台案例
  • Spark--算子执行原理
  • 事件驱动架构(EDA)
  • C++ 入门速通-第5章【黑马】
  • 2025春招,深度思考MyBatis面试题
  • 排序算法--冒泡排序
  • 简易C语言矩阵运算库
  • 通过C/C++编程语言实现“数据结构”课程中的链表
  • 【分布式架构理论3】分布式调用(2):API 网关分析
  • 基于Kamailio、MySQL、Redis、Gin、Vue.js的微服务架构
  • 6S模型的编译问题解决