当前位置: 首页 > news >正文

机器学习笔记——K近邻算法、手写数字识别

KNN算法

“物以类聚,人以群分”相似的数据往往拥有相同的类别
其大概原理就是一个样本归到哪一类,当前样本需要归到频次最高的哪个类去
也就是说有一个待分类的样本,然后跟他周围的k个样本来看,k中哪一个类最多,待分类的样本就是哪一个。
那就以手写数字识别为例吧

import matplotlib.pyplot as plt
import numpy as np
import os
#%%
# 读入mnist数据集
m_x = np.loadtxt('./data/mnist_x', delimiter=' ')
m_y = np.loadtxt('./data/mnist_y')
#%%
# 数据集可视化
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
plt.figure()
plt.imshow(data, cmap='gray')
#%%
# 将数据集分为训练集和测试集
ratio = 0.8
split = int(len(m_x) * ratio)
# 打乱数据
np.random.seed(0)
idx = np.random.permutation(np.arange(len(m_x))) #随机排序
m_x = m_x[idx]
m_y = m_y[idx]
x_train, x_test = m_x[:split], m_x[split:]
y_train, y_test = m_y[:split], m_y[split:]
#%%
#定义距离函数
def distance(x,y):return np.sqrt(np.sum(np.square(x-y)))#%%
#定义KNN模型
class KNN:def __init__(self,k,label_num):self.k=kself.label_num=label_num #类别的数量def fit(self,x_train,y_train):self.x_train=x_trainself.y_train=y_traindef get_knn_indices(self,x): #获得距离目标样本最近的k个点的标签,a来做self_x.traindis=list(map(lambda a:distance(a,x),self.x_train))knn_indices=np.argsort(dis) #对距离排序,在选择k个出来knn_indices=knn_indices[:self.k]#标签return knn_indicesdef get_label(self,x):#计算k个点中,样本的标签数量是多少knn_indices=self.get_knn_indices(x)label_statistic=np.zeros(shape=[self.label_num])for index in knn_indices:label=int(self.y_train[index])label_statistic[label]+=1return np.argmax(label_statistic) #找出最大的类别def predict(self,x_test):predicted_test_labels=np.zeros(shape=[len(x_test)],dtype=int)for i,x in enumerate(x_test): #枚举predicted_test_labels[i]=self.get_label(x)return predicted_test_labels#%%
for k in range(1,10):knn=KNN(k,label_num=10)knn.fit(x_train,y_train)predicted_labels=knn.predict(x_test)accuracy=np.mean(predicted_labels==y_test)print(f'k的取值为{k},预测准确率为{accuracy*100:.lf}%')
http://www.lryc.cn/news/355034.html

相关文章:

  • 基于STM32实现智能园艺系统
  • 网络原理-HTTP协议
  • 【ES001】elasticsearch实战经验总结(最近更新中)
  • OpenBayes 一周速览|TripoSR 开源:1 秒即 2D 变 3D、经典 GTZAN 音乐数据集上线
  • 【论文笔记】advPattern
  • 【鱼眼镜头11】Kannala-Brandt模型和Scaramuzza多项式模型区别,哪个更好?
  • 微信小程序仿胖东来轮播和背景效果(有效果图)
  • 10.SpringBoot 统一处理功能
  • 【八股系列】为什么会有webpack配置?webpack的构建流程是什么?
  • sdf 测试-2-openssl
  • 头歌springboot初体验
  • 矩阵对角化在机器学习中的奥秘与应用
  • 操作MySQL数据库
  • Linux shell 文件生成文件脚本(模拟生成文件、生成大量文件)
  • theharvester一键收集域名信息(KALI工具系列十)
  • 「动态规划」删除并获得点数
  • MongoDB CRUD操作:内嵌文档数组查询
  • 【C++】每日一题 50 Pow(x,n)
  • HG/T 6088-2022 透水道路用涂料检测
  • linux定时清理docker日志脚本
  • ROS学习笔记(16):夹缝循迹
  • 【MySQL精通之路】SQL语句(3)-锁和事务语句
  • 211大学计算机专业不考408,新增的交叉专业却考408!南京农业大学计算机考研考情分析!
  • 利用java8 的 CompletableFuture 优化 Flink 程序,性能提升 50%
  • 香橙派 AIpro综合体验及AI样例运行
  • 通过域名接口申请免费的ssl多域名证书
  • 【JAVA WEB实用与优化技巧】如何自己封装一个自定义UI的Swagger组件,包含Swagger如何处理JWT无状态鉴权自动TOKEN获取
  • 理解大语言模型(二)——从零开始实现GPT-2
  • SSH远程登录时常见问题解决
  • 工业级3D开发引擎HOOPS:创新与效率的融合!