当前位置: 首页 > news >正文

MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

1、查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、fashion_mnist数据集下载与展示

(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?def plot_images_lables(images,labels,start_idx,num=5):fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

在这里插入图片描述

3、数据预处理

X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

5、模型配置

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

6、模型训练

H = model.fit(x=X_train,y=y_train,validation_split=0.2,# validation_data=(X_test,y_test),epochs=10,batch_size=128,verbose=1)

在这里插入图片描述

plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

在这里插入图片描述

7、模型评估

model.evaluate(X_test,y_test)

在这里插入图片描述

8、模型预测

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

在这里插入图片描述

9、模型保存与加载

import numpy as nptf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))

在这里插入图片描述

http://www.lryc.cn/news/341741.html

相关文章:

  • ChatGPT-税收支持新质生产力
  • Linux下深度学习虚拟环境的搭建与模型训练
  • Map-Reduce是个什么东东?
  • 上位机工作感想-从C#到Qt的转变-2
  • 【C++】C++ 中 的 lambda 表达式(匿名函数)
  • OpenSSL实现AES-CBC加解密,可一次性加解密任意长度的明文字符串或字节流(QT C++环境)
  • cURL:命令行下的网络工具
  • Baumer工业相机堡盟工业相机如何通过NEOAPISDK查询和轮询相机设备事件函数(C#)
  • Day45代码随想录动态规划part07:70. 爬楼梯(进阶版)、322. 零钱兑换、279.完全平方数、139.单词拆分
  • 土壤重金属含量分布、Cd镉含量、Cr、Pb、Cu、Zn、As和Hg、土壤采样点、土壤类型分布
  • 力扣:100284. 有效单词(Java)
  • 如何快速掌握DDT数据驱动测试?
  • OpenCV如何实现背投(58)
  • 5-在Linux上部署各类软件
  • 【Jenkins】持续集成与交付 (八):Jenkins凭证管理(实现使用 SSH 、HTTP克隆Gitlab代码)
  • 开源模型应用落地-CodeQwen模型小试-SQL专家测试(二)
  • Arch Linux安装macOS
  • 接口自动化框架篇:Pytest + Allure报告企业定制化实现!
  • 保持 Hiti 证卡打印机清洁的重要性和推荐的清洁用品
  • Unity C#的底层原理概述
  • 国产数据库的发展势不可挡
  • 权益商城系统源码 现支持多种支付方式
  • python安装问题及解决办法(pip不是内部或外部命令也不是可运行)
  • Json高效处理方法
  • 若依分离版-前端使用echarts组件
  • android native开发
  • Partisia Blockchain 生态zk跨链DEX上线,加密资产将无缝转移
  • Vue3组合式API + TS项目中手写国际化插件
  • 深入解析Jackson的ObjectMapper:核心功能与方法指南
  • 计算机是如何执行指令的