当前位置: 首页 > news >正文

机器学习(九):KNN算法全解析与项目实践

声明:未经允许,禁止转载

算法原理

kkk近邻算法是一种常见的懒惰监督学习算法,懒惰指的是该算法不需要训练,只需要计算测试样本到各个训练集样本的距离,然后找出距离测试样本最近的kkk个训练样本:

  • 对于分类任务,使用投票法,将kkk个训练样本的大多数样本所属的类别作为该预测样本的类别;
  • 对于回归任务,将kkk个样本实值的均值作为预测结果。

kkk近邻算法中,kkk和距离度量的选择都十分重要,kkk取不同值时,分类结果可能会差异很大(见下图),而距离度量的不同会影响kkk个近邻的选择。

k-nearest-neighbor

算法实践

本文在鸢尾花数据集上进行kkk近邻算法的实践。数据集按训练集和测试集7:37:37:3的比例划分,具体实现代码如下:

import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import mode
from process import read_data
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as pltclass KNNClassifier:def __init__(self, k, metric='euclidean'):self.k = kself.train_x = Noneself.train_y = Noneself.metric = metricdef fit(self, x, y):self.train_x = xself.train_y = ydef euclidean_distance(self, x1, x2):return cdist(x1, x2, metric=self.metric)def predict(self, test_x):# 距离计算dis_x = self.euclidean_distance(test_x, self.train_x)# 获取测试样本对应的k个最近邻的索引k_indices = np.argsort(dis_x, axis=1)[:, :self.k]# 根据索引获取对应的标签k_labels = self.train_y[k_indices]# 对每个样本的k个最近邻的标签进行投票选出频率最高的作为预测概率y_hat, _ = mode(k_labels, axis=1)return y_hatif __name__ == "__main__":data_path = "../datasets/iris.csv"metric = 'euclidean'train_x, train_y, test_x, test_y = read_data(data_path)k_max = 20accs = []for k in range(1, k_max + 1):# 自建模型cknn = KNNClassifier(k, metric)cknn.fit(train_x, train_y)cy_hat = cknn.predict(test_x)custom_acc = accuracy_score(test_y.reshape(-1), cy_hat.reshape(-1))# 调库sknn = KNeighborsClassifier(n_neighbors=k, metric=metric)sknn.fit(train_x, train_y)sy_hat = sknn.predict(test_x)sklearn_acc = accuracy_score(test_y.reshape(-1), sy_hat.reshape(-1))accs.append(custom_acc)print("k[{}]\tcustom acc[{}]\tsklearn acc[{}]".format(k, custom_acc, sklearn_acc))x = np.arange(1, k_max + 1)plt.figure()plt.xlabel("k")plt.ylabel("accuracy")plt.xticks(x)plt.plot(x, accs, marker="<")plt.savefig("accuracy.png")

实践结果中计算了k=[1,2,3,...,20]k=[1,2,3,...,20]k=[1,2,3,...,20]时的准确率,下图为对应的结果,横轴为kkk的取值,纵轴为对应在测试集上的准确率:

knn-accuracy

结语

参考资料:《机器学习》周志华
源码地址:KNN
以上便是本文的全部内容,若有任何错误敬请批评指正,觉得不错的话可以支持一下,不胜感激!!!。

http://www.lryc.cn/news/600152.html

相关文章:

  • C/C++---I/O性能优化
  • 谁将统治AI游戏时代?腾讯、网易、米哈游技术暗战
  • 《C++ vector 完全指南:vector的模拟实现》
  • LeetCode|Day25|389. 找不同|Python刷题笔记
  • UE5多人MOBA+GAS 30、技能升级机制
  • 动漫花园资源网在线观看,动漫花园镜像入口
  • 基于Java的健身房管理系统
  • HTTP 与 SpringBoot 参数提交与接收协议方式
  • [MMU]TLB Miss 后的 Hardware Table Walk及优化
  • AI与区块链融合:2025年的技术革命与投资机遇
  • c语言-数据结构-沿顺相同树解决对称二叉树问题的两种思路
  • Web前端:JavaScript Math内置对象
  • ABP VNext + OData:实现可查询的 REST API
  • MyBatis-Plus极速开发指南
  • Springboot3.0 集成 RocketMQ5
  • 理解Spring中的IoC
  • 数字增加变化到目标数值动画,js实现
  • 2025年-ClickHouse 高性能实时分析数据库(大纲版)
  • cha的操作
  • 最新Amos 29下载及详细安装教程,附免激活中文版Amos安装包
  • Nature Communications:复杂光下多维视觉信息处理,利用时间演变的环境极化敏感神经突触器件
  • 基于Docker的GPU版本飞桨PaddleOCR部署深度指南(国内镜像)2025年7月底测试好用:从理论到实践的完整技术方案
  • JavaScript 中 let 在循环中的作用域机制解析
  • 【深度学习新浪潮】Claude code是什么样的一款产品?
  • swiper.js实现名录上下滚动
  • Python 条件分支与循环详解--python004
  • 【Agent】API Reference Manual(API 参考手册)
  • 【Spring AI详解】开启Java生态的智能应用开发新时代(附不同功能的Spring AI实战项目)
  • 手写A2C(FrozenLake环境)
  • 牛客刷题记录01