当前位置: 首页 > news >正文

matplotlib 动态显示训练过程中的数据和模型的决策边界

文章目录

  • Github
  • 官网
  • 文档
  • 简介
  • 动态显示训练过程中的数据和模型的决策边界
    • 安装
    • 源码

Github

  • https://github.com/matplotlib/matplotlib

官网

  • https://matplotlib.org/stable/

文档

  • https://matplotlib.org/stable/api/index.html

简介

matplotlib 是 Python 中最常用的绘图库之一,用于创建各种类型的静态、动态和交互式可视化。

动态显示训练过程中的数据和模型的决策边界

在这里插入图片描述

安装

pip install tensorflow==2.13.1
pip install matplotlib==3.7.5
pip install numpy==1.24.3

源码

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap# 生成数据
np.random.seed(0)
num_samples_per_class = 500
negative_samples = np.random.multivariate_normal(mean=[0, 3],cov=[[1, 0.5], [0.5, 1]],size=num_samples_per_class
)
positive_samples = np.random.multivariate_normal(mean=[3, 0],cov=[[1, 0.5], [0.5, 1]],size=num_samples_per_class
)inputs = np.vstack((negative_samples, positive_samples)).astype(np.float32)
targets = np.vstack((np.zeros((num_samples_per_class, 1)), np.ones((num_samples_per_class, 1)))).astype(np.float32)# 将数据分为训练集和测试集
train_size = int(0.8 * len(inputs))
X_train, X_test = inputs[:train_size], inputs[train_size:]
y_train, y_test = targets[:train_size], targets[train_size:]# 构建二分类模型
model = Sequential([# 输入层:输入形状为 (2,)# 第一个隐藏层:包含 4 个节点,激活函数使用 ReLUDense(4, activation='relu', input_shape=(2,)),# 输出层:包含 1 个节点,激活函数使用 Sigmoid(因为是二分类问题)Dense(1, activation='sigmoid')
])# 编译模型
# 指定优化器为 Adam,损失函数为二分类交叉熵,评估指标为准确率
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 准备绘图
fig, ax = plt.subplots()
cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#0000FF'])# 动态绘制函数
def plot_decision_boundary(epoch, logs):ax.clear()x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),np.arange(y_min, y_max, 0.1))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.predict(grid).reshape(xx.shape)ax.contourf(xx, yy, probs, alpha=0.8, cmap=cmap_light)ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train[:, 0], edgecolor='k', cmap=cmap_bold)ax.set_title(f'Epoch {epoch+1}')plt.draw()plt.pause(0.01)# 自定义回调函数
class PlotCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):plot_decision_boundary(epoch, logs)# 训练模型并动态显示
plot_callback = PlotCallback()
model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[plot_callback])# 评估模型
loss, accuracy = model.evaluate(X_test, y_test)
print(f"Test Loss: {loss}")
print(f"Test Accuracy: {accuracy}")plt.show()
http://www.lryc.cn/news/367762.html

相关文章:

  • 【学术小白成长之路】02三方演化博弈(基于复制动态方程)期望与复制动态方程
  • 短剧看剧系统投流版系统搭建,前端uni-app
  • 最新的ffmepg.js前端VUE3实现视频、音频裁剪上传功能
  • “Apache Kylin 实战指南:从安装到高级优化的全面教程
  • 【iOS】内存泄漏检查及原因分析
  • “深入探讨Java中的对象拷贝:浅拷贝与深拷贝的差异与应用“
  • Docker 进入指定容器内部(以Mysql为例)
  • 计算机网络-数制转换与子网划分
  • 【ssh命令】ssh登录远程服务器
  • 【区块链】truffle测试
  • 【AIGC调研系列】chatTTS与GPT-SoVITS的对比优劣势
  • LLVM Cpu0 新后端10
  • k8s面试题大全,保姆级的攻略哦(二)
  • Mysql:通过一张表里的父子级,递归查询并且分组分级
  • 数据结构之排序算法
  • 移动安全赋能化工能源行业智慧转型
  • 今天是放假带娃的一天
  • linux Ubuntu安装samba服务器与SSH远程登录
  • 纳什均衡:博弈论中的运作方式、示例以及囚徒困境
  • Linux之进程信号详解【上】
  • 【Spring Cloud】Eureka详细介绍及底层原理解析
  • 【清华大学】《自然语言处理》(刘知远)课程笔记 ——NLP Basics
  • 代码随想录 | Day17 | 二叉树:二叉树的最大深度最小深度
  • 【Linux】Socket编程基础
  • 关于stm32的软件复位
  • 规范系统运维:系统性能监控与优化的重要性与实践
  • 用python编撰一个电脑清理程序
  • 2024年【天津市安全员C证】免费试题及天津市安全员C证试题及解析
  • 【Python数据挖掘实战案例】机器学习LightGBM算法原理、特点、应用---基于鸢尾花iris数据集分类实战
  • 使用LabVIEW进行大数据数组操作的优化方法