当前位置: 首页 > news >正文

基于最近邻数据进行分类

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

完整代码:

import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt# 生成一个简单的数据集 (2个特征和2个分类)
# X为输入特征,y为标签
X = np.array([[1, 2], [2, 3], [3, 4], [5, 7], [6, 8], [7, 9], [8, 10], [3, 6], [4, 5], [6, 4]])
y = np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 1])# 数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)# 打印数据
print("Features:")
print(X_tensor)
print("Labels:")
print(y_tensor)# 使用 sklearn KNN 分类器,调整邻居数量为 5
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X, y)# 预测
y_pred = knn.predict(X)# 计算准确率
accuracy = accuracy_score(y, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")# 可视化数据
plt.figure(figsize=(6, 4))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='bwr', marker='o', edgecolor='k', s=100)
plt.title("KNN Classification Example")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()# 测试:给定新的输入数据进行预测
test_data = np.array([[5, 6], [2, 3]])
test_prediction = knn.predict(test_data)print(f"Predictions for test data {test_data} are {test_prediction}")
  • 生成数据:创建了一个具有 2 个特征和 2 个类别标签的数据集。X 是输入特征,y 是标签。
  • 转换为 PyTorch 张量:虽然这里我们不需要在 KNN 算法中使用 PyTorch,但我们将数据转换为 PyTorch 张量,显示如何与 PyTorch 数据结构进行交互。
  • KNN 分类器:使用 sklearn.neighbors.KNeighborsClassifier 创建并训练 KNN 模型。我们将 n_neighbors 设置为 5,即选择 5 个最近邻。
  • 预测与准确率:使用训练好的模型对所有数据进行预测,并计算准确率。
  • 可视化:使用 matplotlib 将数据点可视化,数据点的颜色根据标签进行区分。
  • 测试预测:我们对新的测试数据点 [5, 6][2, 3] 进行预测。
  • 结果:
  • Features:
    tensor([[ 1.,  2.],[ 2.,  3.],[ 3.,  4.],[ 5.,  7.],[ 6.,  8.],[ 7.,  9.],[ 8., 10.],[ 3.,  6.],[ 4.,  5.],[ 6.,  4.]])
    Labels:
    tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 1])
    Accuracy: 90.00%
    Predictions for test data [[5 6][2 3]] are [1 0]

http://www.lryc.cn/news/530580.html

相关文章:

  • DeepSeek V3 vs R1:大模型技术路径的“瑞士军刀“与“手术刀“进化
  • 一、TensorFlow的建模流程
  • 指导初学者使用Anaconda运行GitHub上One - DM项目的步骤
  • 7层还是4层?网络模型又为什么要分层?
  • C++:抽象类习题
  • C++ 泛型编程指南02 (模板参数的类型推导)
  • 音视频入门基础:RTP专题(5)——FFmpeg源码中,解析SDP的实现
  • 计算机网络 应用层 笔记 (电子邮件系统,SMTP,POP3,MIME,IMAP,万维网,HTTP,html)
  • 【视频+图文详解】HTML基础3-html常用标签
  • FreeRTOS学习 --- 消息队列
  • PHP If...Else 语句详解
  • pytorch使用SVM实现文本分类
  • 安卓(android)读取手机通讯录【Android移动开发基础案例教程(第2版)黑马程序员】
  • 【Qt】常用的容器
  • 基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF
  • 分布式事务组件Seata简介与使用,搭配Nacos统一管理服务端和客户端配置
  • JavaScript常用的内置构造函数
  • 25寒假算法刷题 | Day1 | LeetCode 240. 搜索二维矩阵 II,148. 排序链表
  • MQTT知识
  • 【机器学习与数据挖掘实战】案例11:基于灰色预测和SVR的企业所得税预测分析
  • 新一代搜索引擎,是 ES 的15倍?
  • 使用 Context API 管理临时状态,避免 Redux/Zustand 的持久化陷阱
  • PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统
  • word2vec 实战应用介绍
  • C# 操作符重载对象详解
  • python学opencv|读取图像(五十四)使用cv2.blur()函数实现图像像素均值处理
  • CNN的各种知识点(四): 非极大值抑制(Non-Maximum Suppression, NMS)
  • 虚幻基础16:locomotion direction
  • C++游戏开发实战:从引擎架构到物理碰撞
  • 代理模式——C++实现