当前位置: 首页 > news >正文

Tensorflow2.0:CNN、ResNet实现MNIST分类识别

以下仅是个人的学习笔记 ,内容可能是错误

CNN: 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 导入数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 数据预处理
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0# 构建模型
model = keras.Sequential([layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D(pool_size=(2, 2)),layers.Flatten(),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

ResNet18: 

import tensorflow as tf
from keras import layers, models, datasets
import os# 定义gpu
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定GPU编号
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:tf.config.experimental.set_memory_growth(gpus[0], True)  # 动态申请显存except RuntimeError as e:print(e)# 加载数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()# 数据预处理
train_images, test_images = train_images / 255.0, test_images / 255.0# 搭建残差模块
def resnet_block(inputs, num_filters=16, kernel_size=3, strides=1, activation='relu'):x = layers.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(inputs)x = layers.BatchNormalization()(x)if activation:x = layers.Activation(activation)(x)return x# 定义resnet
def resnet18():inputs = layers.Input(shape=(32, 32, 3))num_filters = 64t = layers.BatchNormalization()(inputs)t = resnet_block(t, num_filters=num_filters)for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = layers.AveragePooling2D()(t)outputs = layers.Dense(10, activation='softmax')(layers.Flatten()(t))model = models.Model(inputs, outputs)return model# 定义模型
model = resnet18()
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练 CPU
# history = model.fit(train_images, train_labels, epochs=10,
#                     validation_data=(test_images, test_labels))with tf.device('GPU:0'):  # 指定使用GPUhistory = model.fit(train_images, train_labels, epochs=10,validation_data=(test_images, test_labels))

 

http://www.lryc.cn/news/236585.html

相关文章:

  • 本地jar导入maven
  • 数据结构与算法【堆】的Java实现
  • 同创永益联合红帽打造一站式数字韧性解决方案
  • c++ call_once 使用详解
  • 【rosrun diagnostic_analysis】报错No module named rospkg | ubuntu 20.04
  • 高防CDN有什么作用?
  • 盛元广通开放实训室管理系统2.0
  • 3D建模基础教程:编辑多边形功能命令快捷方式
  • SaleSmartly新增AI意图识别触发器!让客户享受更精准的自动化服务
  • 计算机毕业设计选题推荐-个人博客微信小程序/安卓APP-项目实战
  • 一篇详解,Postman设置token依赖步骤
  • 音频录制实现 绘制频谱
  • nginx代理本地服务请求,避免跨域;前端图片压缩并上传
  • Vue3-readonly(深只读) 与 shallowReadonly(浅只读)
  • 中小企业怎么实现数字化转型?有什么实用的工单管理系统?
  • vue3.x中父组件添加自定义参数后,如何获取子组件$emit传递过来的参数
  • 【Machine Learning in R - Next Generation • mlr3】
  • CorelDraw2024(CDR)- 矢量图制作软件介绍
  • RT-DETR优化改进:轻量级Backbone改进 | VanillaNet极简神经网络模型 | 华为诺亚2023
  • 本地部署 EmotiVoice易魔声 多音色提示控制TTS
  • 5g路由器赋能园区无人配送车联网应用方案
  • ARTS 打卡第一周
  • 第八部分:JSP
  • Github小彩蛋显示自己的README,git 个人首页的 README,readme基本语法
  • dxva2+ffmpeg硬件解码(Windows)终结发布
  • C#密封类、偏类
  • C++菱形继承问题
  • 第20章 数据库编程
  • PS学习笔记——初识PS界面
  • JDBC,Java连接数据库