当前位置: 首页 > news >正文

Python + 深度学习从 0 到 1(02 / 99)

希望对你有帮助呀!!💜💜 如有更好理解的思路,欢迎大家留言补充 ~ 一起加油叭 💦
欢迎关注、订阅专栏 【深度学习从 0 到 1】谢谢你的支持!

⭐ 手写数字分类: Keras + MNIST 数据集

手写数字分类任务

任务:将手写数字的灰度图像(28像素×28像素)划分到10个类别中(0~9)

MNIST数据集包含60 000张训练图像和10 000张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即 MNIST 中的NIST)在20世纪80年代收集得到

  • 样本示例如下:(hint: 显示数据集的第一个数字的代码:plt.imshow(train_images[0], cmap=plt.cm.binary))
💜步骤一 : 加载Keras中的MNIST数据集
from keras.datasets import mnist 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()  # 包括4个Numpy数组# 准备数据 
train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype('float32') / 255 
test_images  = test_images.reshape((10000, 28 * 28)) 
test_images  = test_images.astype('float32') / 255# 准备标签
from keras.utils import to_categorical 
train_labels = to_categorical(train_labels) 
test_labels  = to_categorical(test_labels)
💜步骤二 : 构建网络架构 (两层全连接层为例)
from keras import models 
from keras import layers 
network = models.Sequential() 
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,))) 
network.add(layers.Dense(10, activation='softmax'))
💜步骤三: 编译步骤 (optimizer + loss + metrics)
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
💜步骤四:训练网络
network.fit(train_images, train_labels, epochs=5, batch_size=128)
💜步骤五:测试网络
 test_loss, test_acc = network.evaluate(test_images, test_labels) 

完整代码参考:

from keras.datasets import mnist 
from keras import models 
from keras import layers (train_images, train_labels), (test_images, test_labels) = mnist.load_data()  # 包括4个Numpy数组# 准备数据 
train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype('float32') / 255 
test_images  = test_images.reshape((10000, 28 * 28)) 
test_images  = test_images.astype('float32') / 255# 准备标签
from keras.utils import to_categorical 
train_labels = to_categorical(train_labels) 
test_labels  = to_categorical(test_labels)# 构建网络架构
network = models.Sequential() 
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,))) 
network.add(layers.Dense(10, activation='softmax'))# 编译步骤
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])# 训练网络
network.fit(train_images, train_labels, epochs=5, batch_size=128)# 测试网络
test_loss, test_acc = network.evaluate(test_images, test_labels) print("Loss: {}, Acc: {}".format(test_loss, test_acc))

----- 结束后会得到类似如下结果:

Epoch 1/5
469/469 [==============================] - 2s 5ms/step - loss: 0.2598 - accuracy: 0.9253
Epoch 2/5
469/469 [==============================] - 2s 5ms/step - loss: 0.1041 - accuracy: 0.9692
Epoch 3/5
469/469 [==============================] - 2s 5ms/step - loss: 0.0684 - accuracy: 0.9795
Epoch 4/5
469/469 [==============================] - 2s 5ms/step - loss: 0.0492 - accuracy: 0.9848
Epoch 5/5
469/469 [==============================] - 2s 5ms/step - loss: 0.0367 - accuracy: 0.9892
313/313 [==============================] - 0s 702us/step - loss: 0.0665 - accuracy: 0.9803
Loss: 0.06652633100748062, Acc: 0.9803000092506409

参考书籍:Python 深度学习

http://www.lryc.cn/news/509745.html

相关文章:

  • HTML+CSS+JS制作在线书城网站(内附源码,含5个页面)
  • 【FastAPI】中间件
  • 5个实用的设计相关的AI网站
  • STL 六大组件
  • Python选择题训练工具:高效学习、答题回顾与音频朗读一站式体验
  • Python实现机器学习驱动的智能医疗预测模型系统的示例代码框架
  • AWS Certified AI Practitioner 自学考试心得
  • JQ中的each()方法与$.each()函数的使用区别
  • 滚珠丝杆与直线导轨的区别
  • 【Ovis】Ovis1.6的本地部署及推理
  • C语言结构体位定义(位段)的实际作用深入分析
  • 儿童影楼管理系统:基于SSM的创新设计与功能实现
  • 青蛇人工智能学家
  • uniapp+vue 前端防多次点击表单,防误触多次请求方法。
  • 【ES6复习笔记】rest参数(7)
  • Hive SQL 窗口函数 `ROW_NUMBER() ` 案例分析
  • 前端mock数据 —— 使用Apifox mock页面所需数据
  • 车载U盘制作教程:轻松享受个性化音乐
  • springboot 3 websocket react 系统提示,选手实时数据更新监控
  • 现代图形API综合比较:Vulkan DirectX Metal WebGPU
  • 【Hot100刷题计划】Day04 栈专题 1~3天回顾(持续更新)
  • 用VBA将word文档处理成支持弹出式注释的epub文档可用的html内容
  • 舵机原理介绍 简洁讲解面向实战 非阻塞式驱动代码, arduino
  • Oracle Database 23ai 中的DBMS_HCHECK
  • 如何利用AWS监听存储桶并上传到tg bot
  • STM32 SPI读取SD卡
  • TANGO与LabVIEW控制系统集成
  • eth_type_trans 函数
  • 派克汉尼汾推出新的快换接头产品系列,扩展热管理解决方案
  • uniapp 前端解决精度丢失的问题 (后端返回分布式id)