当前位置: 首页 > news >正文

机器学习项目一基于KNN算法的手写数字识别

    1、问题分析

     手写数字识别是计算机视觉与机器学习领域的经典问题,核心目标是让计算机自动识别人类手写的 0-9 数字。该问题的关键挑战在于:

  • 手写数字存在形态差异(如笔画粗细、倾斜角度、书写风格不同);
  • 图像数据需转换为机器可处理的数值形式;
  • 需通过训练模型学习数字的特征规律,实现对未知数字的准确分类。

本代码基于已有手写数字图像数据集(shouxieshuzi.png),使用 KNN(K - 近邻)算法实现识别,并通过测试集评估模型准确率。 

 2、解题思路

        针对手写数字识别问题,整体解题流程遵循 “数据处理→模型训练→评估验证” 的经典机器学习框架,具体步骤如下:

  1. 数据加载与预处理:加载原始图像,转换为灰度图,并分割为独立的单数字图像(每个数字为 20×20 像素);
  2. 数据集划分:将分割后的数字图像分为训练集(用于模型学习)和测试集(用于评估模型性能),比例为 1:1;
  3. 特征工程:将 20×20 的二维图像展平为 1×400 的一维向量(以像素值作为特征);
  4. 标签分配:为训练集和测试集的每个数字分配真实标签(0-9);
  5. 模型训练与预测:使用 KNN 算法训练模型,通过测试集预测并计算准确率。

数据集照片 

 

3、代码分析

1. 库导入

import numpy as np
import cv2  # 本身就自定了机器学习的函数
  • numpy:用于数值计算(如数组分割、重塑、维度调整);
  • cv2(OpenCV):用于图像处理(如加载图像、转换灰度图),并内置 KNN 模型接口。

2. 图像加载与预处理 

img = cv2.imread('shouxieshuzi.png')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # 转换为灰度图
  • 图像加载:通过cv2.imread读取原始图像shouxieshuzi.png(该图像是包含 5000 个手写数字的集合,按 50 行 ×100 列排列);
  • 灰度转换:通过cv2.COLOR_BGR2GRAY将 BGR 格式的彩色图转换为灰度图(单通道,像素值范围 0-255),目的是减少数据维度(从 3 通道降至 1 通道),降低计算量。

 

3. 图像分割(获取单数字样本) 

# 将原始图像分割成独立的数字,每个数字大小20*20,共5000个
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]  # python列表
# 装成array,形状(50,100,20,20),50行,100列,每个图像20*20大小
x = np.array(cells)
  • 分割逻辑:原始图像是 50 行 ×100 列的数字网格(共 50×100=5000 个数字),每个数字尺寸为 20×20 像素。
    • np.vsplit(gray, 50):将灰度图按行分割为 50 个部分(每部分为 1 行数字,高度 = 原始高度 / 50);
    • np.hsplit(row, 100):将每行再按列分割为 100 个部分(每部分为 1 个数字,宽度 = 原始宽度 / 100);
  • 数据格式:分割后得到列表cells,转换为 numpy 数组x后,形状为(50,100,20,20),即 “50 行 ×100 列 ×20 像素高 ×20 像素宽”。

 

4. 训练集与测试集划分

# 划分为训练集和测试集,比例各占一半
train = x[:, :50]  # 取每一行的前50列数字作为训练集
test = x[:, 50:100]  # 取每一行的后50列数字作为测试集

 划分逻辑:为评估模型泛化能力,将 5000 个数字分为两部分:

  • 训练集:50 行 ×50 列 = 2500 个数字(用于模型学习特征);
  • 测试集:50 行 ×50 列 = 2500 个数字(用于验证模型性能)

 5. 特征向量转换

# 将数据构造为符合KNN的输入,将每个数字的尺寸由20*20调整为1*400(一行*400个像素)
train_new = train.reshape(-1, 400).astype(np.float32)  # Size = (2500, 400)
test_new = test.reshape(-1, 400).astype(np.float32)    # Size = (2500, 400)
  • 转换原因:KNN 算法要求输入为 “样本数 × 特征数” 的二维数组(每行一个样本,每列一个特征)。
  • 操作细节
    • reshape(-1, 400):将 20×20 的二维图像展平为 1×400 的一维向量(20×20=400 像素,每个像素值作为一个特征);
    • astype(np.float32):转换为浮点型,符合 OpenCV 中 KNN 模型的输入格式要求。

6. 标签分配

# 分配标签,分别为训练数据、测试数据分配标签(图像对应的实际值)
k = np.arange(10)  # 生成0-9的数字序列
labels = np.repeat(k, 250)  # 每个数字重复250次(2500个样本/10个数字=250)
train_labels = labels[:, np.newaxis]  # 转换为列向量(形状(2500,1))
test_labels = np.repeat(k, 250)[:, np.newaxis]  # 测试集标签与训练集结构一致
  • 标签逻辑:每个数字(0-9)在训练集和测试集中各有 250 个样本(2500 个样本 / 10 类 = 250),因此通过np.repeat生成重复标签;
  • 格式调整np.newaxis将一维标签数组转换为列向量(形状 (2500,1)),符合 OpenCV 中 KNN 模型对标签格式的要求。

         每250为一个标签,一共生成10个标签正好2500个数字

7. KNN 模型训练与预测 

# 模型构建+训练
knn = cv2.ml.KNearest_create()  # 创建KNN模型实例
knn.train(train_new, cv2.ml.ROW_SAMPLE, train_labels)  # 训练模型
# 预测测试集
ret, result, neighbours, dist = knn.findNearest(test_new, k=3)  # k=3表示取3个最近邻
  • 模型构建cv2.ml.KNearest_create()创建 KNN 分类器实例;
  • 训练过程knn.train接收 3 个参数:
    • train_new:训练集特征(2500×400);
    • cv2.ml.ROW_SAMPLE:指定训练数据按行组织(每行一个样本);
    • train_labels:训练集标签(2500×1);
  • 预测过程knn.findNearest对测试集进行预测,参数k=3表示通过计算测试样本与训练样本的距离,取最近的 3 个样本的多数标签作为预测结果。返回值包括:
    • result:测试集的预测标签(2500×1);
    • neighbours:最近的 3 个训练样本的索引;
    • dist:测试样本与 3 个最近邻的距离。

8. 模型评估(准确率计算)

# 通过测试集校验准确率
matches = result == test_labels  # 比较预测标签与真实标签,相同为True,否则为False
correct = np.count_nonzero(matches)  # 统计正确预测的样本数
accuracy = correct * 100.0 / result.size  # 准确率=正确数/总样本数×100%
print("当前使用KNN识别手写数字的准确率为:%s" % (accuracy))
  • 评估逻辑:通过对比预测标签(result)与真实标签(test_labels),计算准确率(模型性能的核心指标)。
  • 例如:若 2500 个测试样本中正确识别 2300 个,则准确率为2300/2500×100%=92%

4.完整代码 

import numpy as np
import cv2  # 本身就自定了机器学习的函数img = cv2.imread('shouxieshuzi.png')
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # 转换为灰度图# 将原始图像分割成独立的数字,每个数字大小20*20,共5000个
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]  # 按照顺序来看代码,python列表# 装成array,形状(50,100,20,20),50行,100列,每个图像20*20大小
x = np.array(cells)# 划分为训练集和测试集,比例各占一半
train = x[:, :50]
test = x[:, 50:100]# 将数据构造为符合KNN的输入,将每个数字的尺寸由20*20调整为1*400(一行*400个像素)
train_new = train.reshape(-1, 400).astype(np.float32)  # Size = (2500, 400)
test_new = test.reshape(-1, 400).astype(np.float32)    # Size = (2500, 400)# 分配标签,分别为训练数据、测试数据分配标签(图像对应的实际值)
k = np.arange(10)  # 0123456789
labels = np.repeat(k, 250)  # repeat重复数组中的元素,每个元素重复250次
train_labels = labels[:, np.newaxis]  # np.newaxis是NumPy中的一个特殊对象,用于在数组中增加一个新的维度
test_labels = np.repeat(k, 250)[:, np.newaxis]# 模型构建+训练,sklearn knn...opencv里面也有knn
knn = cv2.ml.KNearest_create()  # 通过cv2创建一个knn模型
knn.train(train_new, cv2.ml.ROW_SAMPLE, train_labels)  # cv2.ml.ROW_SAMPLE:这是一个标志,告诉OpenCV训练数据是按行组织的,即每一行是一个样本。
ret, result, neighbours, dist = knn.findNearest(test_new, k=3)  # knn.predict()# ret:表示查找操作是否成功。
# result:浮点数组,表示测试样本的预测标签。
# neighbours:这是一个整数数组,表示与测试样本最接近的k个邻居的索引。这些索引对应于训练集中的样本,可以用来检查哪些训练样本对预测结果产生了影响。
# dist:这是一个浮点数组,表示测试样本与每个最近邻居之间的距离。这些距离可以帮助理解预测结果的置信度:距离越近,预测通常越可靠。# 通过测试集校验准确率
matches = result == test_labels
correct = np.count_nonzero(matches)
accuracy = correct * 100.0 / result.size
print("当前使用KNN识别手写数字的准确率为:%s" % (accuracy))# 1、输入1张图片,得到这张图片的数字是几?
# 2、改sklearn库来实现,突破性的功能。

总结

      本代码完整实现了基于 KNN 算法的手写数字识别流程,核心思路是 “图像分割→特征展平→标签分配→KNN 训练与评估”。通过将图像像素作为特征,利用 KNN 的 “近邻投票” 机制实现分类,适用于小规模数据集的快速验证。 knn本来就是一个在机器学习中最简单入门的,要理解这可能都整个流程与逻辑。对于刚学习机器学习的路上,我们要从最基础的开始一步一步学习。

 

http://www.lryc.cn/news/601225.html

相关文章:

  • 设计模式(十二)结构型:享元模式详解
  • AI Coding IDE 介绍:Cursor 的入门指南
  • 设计模式(八)结构型:桥接模式详解
  • 以太坊ETF流入量超越比特币 XBIT分析买币市场动态与最新价格
  • 分类预测 | MATLAB基于四种先进的优化策略改进蜣螂优化算法(IDBO)的SVM多分类预测
  • 机器学习—线性回归
  • 数学基础薄弱者的大数据技术学习路径指南
  • Java Ai (day01)
  • Oracle EBS 库存期间关闭状态“已关闭未汇总”处理
  • 【网络协议安全】任务15:DHCP与FTP服务全配置
  • docker与k8s的容器数据卷
  • S7-1500 与 S7-1200 存储区域保持性设置特点详解
  • 三、搭建springCloudAlibaba2021.1版本分布式微服务-springcloud loadbalancer负载均衡
  • Java 大视界 -- Java 大数据机器学习模型在电商客户细分与精准营销活动策划中的应用(367)
  • 机械学习----knn实战案例----手写数字图像识别
  • 人工智能开发框架 04.网络构建
  • spring gateway 配置http和websocket路由转发规则
  • Linux驱动21 --- FFMPEG 音频 API
  • Spring Boot + @RefreshScope:动态刷新配置的终极指南
  • mysql 快速上手
  • 发布 VS Code 扩展的流程:以颜色主题为例
  • 详解力扣高频SQL50题之1164. 指定日期的产品价格【中等】
  • MCP + LLM + Agent 8大架构:Agent能力、系统架构及技术实践
  • 2025年7月25日-7月26日 · AI 今日头条
  • 【测试报告】博客系统(Java+Selenium+Jmeter自动化测试)
  • Jmeter的元件使用介绍:(八)断言器详解
  • OpenResty 高并发揭秘:架构优势与 Linux 优化实践
  • 零基础学习性能测试第六章:性能难点-Jmeter实现海量用户压测
  • 人工智能与城市:城市生活的集成智能
  • FastAPI入门:查询参数模型、多个请求体参数