当前位置: 首页 > news >正文

Pytorch实现感知器并实现分类动画

这个实现包含以下关键部分:

  1. 数据生成:使用用户提供的函数生成两类可线性分离的数据点。

  2. 感知机模型

    • 一个线性层接收二维输入并输出一个值
    • 不使用激活函数(原始感知机形式)
    • 使用均方误差损失函数(MSE)和随机梯度下降优化器
  3. 动态可视化

    • 使用 matplotlib 的 FuncAnimation 创建动画
    • 每帧更新显示当前决策边界和损失值
    • 数据点根据真实标签着色(蓝色为 - 1,红色为 1)
    • 绿色线表示当前感知机的决策边界

运行代码后,你将看到一个动画展示感知机如何逐步学习区分两类数据的决策边界。随着训练的进行,决策边界会不断调整,直到能够正确分离两个类别。

 

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation# 数据生成函数(保持与用户提供的一致)
def generate_data():np.random.seed(0)class_1 = np.random.randn(100, 2) + np.array([2, 2])class_2 = np.random.randn(100, 2) + np.array([-2, -2])labels_1 = np.ones((100, 1))labels_2 = -np.ones((100, 1))data = np.vstack((class_1, class_2))labels = np.vstack((labels_1, labels_2))return torch.Tensor(data), torch.Tensor(labels)# 感知机模型
class Perceptron(nn.Module):def __init__(self):super(Perceptron, self).__init__()self.linear = nn.Linear(2, 1)  # 二维输入,一维输出def forward(self, x):return self.linear(x)# 训练和可视化函数
def train_and_visualize():# 生成数据X, y = generate_data()# 创建模型、损失函数和优化器model = Perceptron()criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 设置图形fig, ax = plt.subplots(figsize=(10, 8))scatter = ax.scatter(X[:, 0], X[:, 1], c=y.numpy().flatten(), cmap='coolwarm', alpha=0.7)line, = ax.plot([], [], 'g-', lw=2)ax.set_xlim(-6, 6)ax.set_ylim(-6, 6)ax.set_title('Perceptron Classification')# 初始化线def init():line.set_data([], [])return line,# 更新函数def update(frame):# 训练一步optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()# 获取当前权重和偏置w1, w2 = model.linear.weight.data[0]b = model.linear.bias.data[0]# 计算决策边界x_vals = np.linspace(-6, 6, 100)y_vals = -(w1 * x_vals + b) / w2# 更新线line.set_data(x_vals, y_vals)ax.set_title(f'Perceptron Classification (Epoch {frame + 1}, Loss: {loss.item():.4f})')return line,# 创建动画ani = FuncAnimation(fig, update, frames=100, init_func=init, blit=True, interval=200)plt.show()return ani# 运行训练和可视化
if __name__ == "__main__":animation = train_and_visualize()

http://www.lryc.cn/news/586633.html

相关文章:

  • 深入理解观察者模式:构建松耦合的交互系统
  • 为什么玩游戏用UDP,看网页用TCP?
  • 【C++详解】STL-priority_queue使用与模拟实现,仿函数详解
  • 信息收集实战
  • 【读书笔记】《C++ Software Design》第九章:The Decorator Design Pattern
  • 设计模式:软件开发的高效解决方案(单例、工厂、适配器、代理)
  • 基于无人机 RTK 和 yolov8 的目标定位算法
  • 一文认识并学会c++模板(初阶)
  • AI 助力编程:Cursor Vibe Coding 场景实战演示
  • 基于 Redisson 实现分布式系统下的接口限流
  • 牛客网50题
  • 【C/C++】编译期计算能力概述
  • [Python] -实用技巧篇1-用一行Python代码搞定日常任务
  • python-range函数
  • 校园幸运抽(抽奖系统)测试报告
  • 第七章应用题
  • HT8313功放入门
  • HashMap的原理
  • 数据结构与算法之美:线索二叉树
  • 蒙特卡洛树搜索方法实践
  • 蓝牙调试抓包工具--nRF Connect移动端 使用详细总结
  • 生成式对抗网络(GAN)模型原理概述
  • Java生产带文字、带边框的二维码
  • 牛客:HJ19 简单错误记录[华为机考][字符串]
  • 009 ST表:静态区间最值的极致优化
  • 面试现场:奇哥扮猪吃老虎,RocketMQ高级原理吊打面试官
  • MyBatis实现分页查询-苍穹外卖笔记
  • comfyUI-controlNet-线稿软边缘
  • python-enumrate函数
  • HarmonyOS从入门到精通:动画设计与实现之六 - 动画曲线与运动节奏控制