当前位置: 首页 > news >正文

机器学习实战·第三章 分类(1)

一、MNIST

MNIST 是经典的手写数字数据集,含 6 万训练图和 1 万测试图,每张为 28×28 像素的灰度数字(0-9)。它是机器学习和深度学习的入门基准,常用于测试分类算法效果。

1.1下载MNIST数据集

import os
import numpy as np
from sklearn.datasets import fetch_openml# 设置下载路径
download_dir = "D:\\deskTop\\机器学习实战\\第三章"
os.makedirs(download_dir, exist_ok=True)# 下载 MNIST 数据集到指定路径
mnist = fetch_openml('mnist_784', version=1, data_home=download_dir, as_frame=False)# 将数据集划分为特征和标签
#X是个70000行,784列的二维数组。y是个70000个元素的一维数组,存储的是70000个值['5' '0' '4' ... '4' '5' '6']
X, y = mnist.data, mnist.target# 将标签转换为整数类型
y = y.astype(np.int8)# 可选:将数据保存为 NumPy 数组
np.save(os.path.join(download_dir, 'mnist_X.npy'), X)
np.save(os.path.join(download_dir, 'mnist_y.npy'), y)# 查看数据的形状和标签
print(f"数据集已保存到:{download_dir}")
print("数据形状:", X.shape)
print("标签形状:", y.shape)

fetch_mldata 在 scikit-learn 0.20 版本后已被移除,改用 fetch_openml 替代。

  • fetch_openml('mnist_784') 是获取 MNIST 数据集的新方式
  • 加载后的数据结构与原来类似,包含数据(data)和标签(target
  • version=1 确保获取的是原始格式的 MNIST 数据

1.2 抓取特征向量并显示

MNIST中每张图片是28×28像素,每个像素的强度从0(白色)到255(黑色)。因此每个图片共28×28=784个特征(像素)。将这些特征(像素)重新组成一个28×28的数组再通过Matplotlib显示出来。

# 导入matplotlib库,用于绘制图像
import matplotlib
# 从sklearn库中导入获取公开数据集的工具
from sklearn.datasets import fetch_openml
# 导入matplotlib的绘图模块,给它起个简单的别名plt
import matplotlib.pyplot as plt# 加载MNIST数据集(手写数字图片集)
# 'mnist_784':这是MNIST数据集的标准名称,784表示每张图片是28×28像素(28×28=784)
# version=1:指定使用第1版数据集,保证数据格式正确
# data_home:指定数据集在电脑中的存储路径(这里是你自己的路径)
# as_frame=False:表示让数据以数组形式返回,方便我们处理图像
mnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)# 从加载的数据集里提取图片数据和对应的标签
# X:存储所有图片的像素数据,每个图片被转换成了784个数字(28×28像素展开)
# y:存储每个图片对应的数字标签(0-9中的一个)
X, y = mnist.data, mnist.target# 从所有图片中选第10000张(计算机计数从0开始,所以10000是第10001张)
some_digit = X[10000]# 把784个数字重新变成28×28的矩阵(恢复成图片的原始尺寸)
some_digit_image = some_digit.reshape(28, 28)# 显示这张图片
# imshow():matplotlib中显示图像的函数
# cmap=matplotlib.cm.binary:用黑白颜色显示(手写数字适合黑白显示)
# interpolation='nearest':让图像边缘清晰,不模糊
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation='nearest')# 关闭图像的坐标轴(不需要显示x轴和y轴)
plt.axis('off')# 弹出窗口显示图片
plt.show()#显示标签
print(y[10000])

1.3 将训练集数据洗牌

# 导入numpy库(用于数值计算,这里用它生成随机序列)
import numpy as np# 生成0到59999的随机排列序列(共60000个数字)
# np.random.permutation(60000)会随机打乱0-59999的顺序,比如可能生成[3, 1, 4, 0, ..., 59998]这样的序列
# 这个序列就是我们用来打乱训练集的"随机索引"
shuffle_index = np.random.permutation(60000)# 根据随机索引重新排列训练集的特征和标签
# X[shuffle_index]:用随机索引重新排序图像数据,让原来的顺序被打乱
# y[shuffle_index]:标签也用同样的随机索引打乱,保证每个图像和它的标签对应关系不变
# 假设原来X[0]对应y[0],打乱后X[新位置]依然对应y[新位置]
X_train, y_train = X[shuffle_index], y[shuffle_index]
  1. 为什么要打乱数据?
    很多机器学习模型(比如梯度下降)在训练时对数据顺序敏感,如果数据按固定顺序排列(比如先全是 0,再全是 1),模型可能学不好。打乱后数据分布更均匀,模型能更稳定地学习。

  2. 代码的核心作用

    • 生成一组 "随机序号"(比如把 60000 个样本的顺序彻底打乱)
    • 用这组序号重新排列图像数据(X)和标签(y),确保每个图像和它的标签始终配对(不会张冠李戴)
  3. 举个小例子
    假设原来数据顺序是:X=[图0, 图1, 图2]y=[0, 1, 2]
    生成的随机索引是[2, 0, 1]
    打乱后变成:X_train=[图2, 图0, 图1]y_train=[2, 0, 1](图像和标签依然对应)

这样处理后,训练集的顺序被随机化,能帮助模型更好地学习通用规律,避免受原始数据顺序的干扰。

二、训练二元分类器

# 从sklearn库中导入获取公开数据集的工具
from sklearn.datasets import fetch_openml  
# 导入处理数组和随机数的工具库
import numpy as np 
# 导入SGD分类器(一种快速的机器学习分类模型)
from sklearn.linear_model import SGDClassifiermnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)X, y = mnist.data, mnist.target# 划分训练集和测试集(MNIST标准划分方式)
# 前60000张作为训练集(让模型学习),后10000张作为测试集(检验模型效果)
#这里的X_train是个二维数组(60000, 784)、X_test (10000, 784) 、y_train(60000,) 、y_test(10000,)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]shuffle_index = np.random.permutation(60000)X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]# 定义二分类任务:判断一张图片"是不是数字5"
# 把原始标签转换成True/False(布尔值)
# y_train_5:训练集中的标签,是5就为True,不是5就为False
y_train_5 = (y_train == '5')
# y_test_5:测试集中的标签,同样用True/False表示是否为5
y_test_5 = (y_test == '5')# 创建SGD分类器实例
# random_state=42:固定随机种子,让每次运行结果一致(方便调试)
sgd_clf = SGDClassifier(random_state=42)# 训练模型:让模型学习如何区分"5"和"非5"
# 用训练集的图片(X_train)和对应标签(y_train_5)进行训练
sgd_clf.fit(X_train, y_train_5)# 选一个样本图片来测试模型(这里选第10000张图片)
some_digit = X[10000]# 用训练好的模型预测:这张图片是不是5?
# 输出结果是[False]表示模型认为不是5,[True]表示认为是5
print(sgd_clf.predict([some_digit]))#输出X[10000]这个数字
print(y_train[10000])

2.1 性能考核

2.1.1 使用交叉验证测量精度

机器学习中的交叉验证(Cross-Validation) 是一种评估模型性能的统计方法,通过多次划分训练数据和验证数据,减少因单次数据划分带来的随机性影响,从而更可靠地估计模型在 unseen(未见过)数据上的泛化能力。

核心思想

将原始训练数据集分为两部分:

  • 训练集(Training Set):用于模型训练。
  • 验证集(Validation Set):用于评估模型在训练过程中的性能,帮助调整超参数或选择模型。

通过多次重复 “划分 - 训练 - 验证” 的过程,取多次评估结果的平均值作为最终性能指标,降低数据划分偶然性的影响。

常见交叉验证方法

  1. K 折交叉验证(K-Fold Cross-Validation)

    • 将训练集平均分为 K 个互斥的子集(称为 “折”,Fold)。
    • 每次用 K-1 个折作为训练集,剩余 1 个折作为验证集,重复 K 次(确保每个折都做过验证集)。
    • 最终性能为 K 次验证结果的平均值。
    • 优点:充分利用数据,结果稳定;缺点:计算成本随 K 增加而提高(常用 K=5 或 10)。
  2. 分层 K 折交叉验证(Stratified K-Fold)

    • 针对分类问题,确保每个折中各类别的样本比例与原始训练集一致(避免某一折中类别失衡)。
    • 例如:二分类问题中原始数据正负样本比例为 3:1,每个折中也保持 3:1。
  3. 留一法(Leave-One-Out,LOO)

    • 极端情况的 K 折交叉验证(K = 样本总数)。
    • 每次留 1 个样本作为验证集,其余全部作为训练集,重复 N 次(N 为样本数)。
    • 优点:结果稳定(无随机划分);缺点:计算量极大(适用于小数据集)。
  4. 随机拆分交叉验证(Shuffle-Split Cross-Validation)

    • 多次随机将数据拆分为训练集和验证集(比例可自定义,如 7:3),重复指定次数(如 10 次)。
    • 优点:灵活,可控制训练集 / 验证集比例;缺点:可能有样本重复出现在验证集中。

在本例中使用三折叠的方式,也就是将训练数据平均分为三个互斥的子集,每次使用两个最为训练集一个作为验证集。

from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.linear_model import SGDClassifier
# 导入交叉验证评分函数,用于评估模型性能
from sklearn.model_selection import cross_val_scoremnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)X, y = mnist.data, mnist.targetX_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]shuffle_index = np.random.permutation(60000)X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')sgd_clf = SGDClassifier(random_state=42)sgd_clf.fit(X_train, y_train_5)# 使用交叉验证评估模型性能
# 参数说明:
# - sgd_clf:要评估的模型
# - X_train:训练集特征
# - y_train_5:训练集标签
# - cv=3:采用3折交叉验证(将训练集分为3份,轮流作为验证集)
# - scoring='accuracy':评估指标为准确率(正确预测的样本比例)
# 输出:3次验证的准确率结果(数组形式)
print(cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy'))

局限性:MNIST 数据集中,“5” 的样本占比约 10%(60000 训练集中约 5421 个 “5”)。即使模型简单地将所有样本预测为 “非 5”,也能达到约 90% 的准确率。此时高准确率并不能反映模型的真实能力(无法识别 “5”)。

假设有一个模型,它不做任何学习,对所有图片都无脑预测为 “不是 5”(即不管输入是什么,输出都是 “非 5”)。

此时:

(正确预测的样本数)÷(总样本数)= 54579 ÷ 60000 ≈ 90.97%

  • 对于 54579 张 “非 5” 图片:模型预测正确(因为确实不是 5)
  • 对于 5421 张 “5” 图片:模型预测错误(把 5 当成了非 5)

这个 “摆烂” 模型完全无法识别 “5”(所有 “5” 都被误判),但准确率却高达 90% 以上。这显然很荒谬 —— 高准确率只是因为 “非 5” 的样本占绝大多数,模型哪怕乱猜 “非 5” 也能蒙对大多数。

这说明:当数据类别不平衡(某一类占比极高)时,准确率会 “欺骗” 我们,让我们误以为模型很好,但实际上模型可能完全没学会核心任务(比如这里的 “识别 5”)

2.1.2 混淆矩阵

混淆矩阵是评估分类模型性能的基础工具,它用一个表格清晰展示模型预测结果与真实标签的匹配情况,能直观反映模型的错误类型(比如 “把 A 错当成 B” 还是 “漏检了 A”),简单的说混淆矩阵其实是模型的成绩单。

以二分类问题(比如 “判断是否为5”)为例,混淆矩阵是一个 2×2 的表格,包含 4 种核心结果:

  • TP:模型猜对了(实际是正类,预测也是正类)​
  • FP:模型瞎报了(实际是负类,却预测为正类)​
  • FN:模型漏检了(实际是正类,却预测为负类)​
  • TN:模型正确排除(实际是负类,预测也是负类)
# 导入需要的工具包
# 从sklearn获取公开数据集的工具
from sklearn.datasets import fetch_openml
# 处理数组的工具
import numpy as np
# 导入SGD分类器(一种快速的机器学习模型)
from sklearn.linear_model import SGDClassifier
# 导入交叉验证相关工具:评估模型和获取预测结果
from sklearn.model_selection import cross_val_score, cross_val_predict
# 导入混淆矩阵工具:展示模型的预测对错情况
from sklearn.metrics import confusion_matrix# 加载MNIST手写数字数据集
# 这个数据集包含70000张手写数字图片,每张图片是28×28像素
# as_frame=False表示返回的是数组格式(不是表格格式)
mnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)# 拆分数据:X是图片的像素数据,y是对应的数字标签
# X是70000行784列的二维数组(70000张图,每张图784个像素值)
# y是70000个元素的一维数组(每个元素是图片对应的数字,比如'5'、'0'等字符串)
X, y = mnist.data, mnist.target# 划分训练集和测试集
# 前60000张图片和标签作为训练集(让模型学习)
# 后10000张作为测试集(之后用来检验模型效果)
# X_train:(60000, 784)的二维数组(训练图片)
# X_test:(10000, 784)的二维数组(测试图片)
# y_train:(60000,)的一维数组(训练图片的真实数字)
# y_test:(10000,)的一维数组(测试图片的真实数字)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]# 生成打乱顺序的索引:0到59999的随机排列
# 作用是打乱训练集的顺序,避免模型学到数据的顺序规律(比如前面都是0,后面都是1)
shuffle_index = np.random.permutation(60000)# 按照随机索引重新排列训练集的图片和标签
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]# 定义二分类任务:只判断图片是不是数字"5"
# 生成新的标签:如果原始标签是'5',就记为True(是5),否则记为False(不是5)
# y_train_5:训练集的二分类标签(60000个布尔值)
# y_test_5:测试集的二分类标签(10000个布尔值)
y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')# 创建SGD分类器实例
# random_state=42:固定随机数种子,让每次运行结果一样(方便调试)
sgd_clf = SGDClassifier(random_state=42)# 用训练集训练模型
# 让模型学习:从图片像素(X_train)中识别出哪些是5(y_train_5为True),哪些不是
sgd_clf.fit(X_train, y_train_5)# 用交叉验证获取模型的预测结果
# 把训练集分成3份,轮流用2份训练、1份预测,最后得到所有样本的预测结果
# y_train_pred:60000个布尔值的一维数组(模型认为每个样本是不是5)
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)# 打印混淆矩阵:展示模型预测的"对错明细"
# 输入真实标签(y_train_5)和预测标签(y_train_pred)
# 输出一个2×2的表格,直观显示模型哪里对了、哪里错了
print(confusion_matrix(y_train_5, y_train_pred))

    • TP = 4193(真实是 5,模型也预测是 5 → 猜对的 5)
    • FP = 773(真实不是 5,模型预测是 5 → 误报成 5 的非 5 样本)
    • FN = 1228(真实是 5,模型预测不是 5 → 漏检的 5)
    • TN = 53806(真实不是 5,模型也预测不是 5 → 正确排除的非 5 样本)

    但是一个完美的分类器只有真正类和真负类,也就是只有对角线上存在数据。

    2.1.3 精度和召回率

    在分类任务中,精度(Precision)也称为精确率、查准率,用于衡量模型预测为正类的样本中,真正是正类的比例 。

    • TP(True Positive,真正例):真实类别为正类,模型预测也为正类的样本数量。比如在手写数字识别中,图片真实数字是 “5”,模型也预测为 “5” 的样本数量。
    • FP(False Positive,假正例):真实类别为负类,模型却预测为正类的样本数量。例如图片真实数字不是 “5”,但模型预测为 “5” 的样本数量。

    举个例子,假设模型预测了 100 个样本为 “5”,其中实际是 “5” 的有 80 个(TP = 80),不是 “5” 却被误判为 “5” 的有 20 个(FP = 20),那么精度为:


    即模型预测为 “5” 的样本中,真正是 “5” 的比例为 80% 。

    召回率(Recall),也叫查全率,是评估分类模型性能的重要指标,尤其在关注 “不能漏掉正例” 的场景(如疾病筛查、安全检测)中非常关键。

    对于 “识别 5” 这类二分类任务,召回率公式为:

    • TP(True Positive):真实是正例(比如真实是 5),模型也预测为正例(模型猜对是 5)
    • FN(False Negative):真实是正例(真实是 5),但模型预测为负例(模型漏检、没认出是 5)

    召回率回答的问题是:“所有真实的正例里,模型成功找出来多少?”

    指标关注问题公式场景侧重
    召回率(Recall)真实正例被模型 “找全” 了吗?TP/FN+TP​不能漏检(如疾病筛查、安防)
    精确率(Precision)模型预测的正例 “真的对吗?”TP/FP+TP​不能误报(如垃圾邮件过滤)
    # 从sklearn的评估指标工具中导入精确率和召回率的计算函数
    # precision_score:用于计算精确率(查准率)
    # recall_score:用于计算召回率(查全率)
    from sklearn.metrics import precision_score, recall_score# ...(省略前面的数据加载、预处理和模型训练代码)# 计算并打印精确率(Precision)
    # 参数说明:
    #   y_train_5:真实标签(布尔值数组),True表示真实是5,False表示真实不是5
    #   y_train_pred:模型预测结果(布尔值数组),True表示模型认为是5,False表示模型认为不是5
    # 精确率公式:TP / (TP + FP) 
    #   TP:真实是5且模型预测是5(猜对的5)
    #   FP:真实不是5但模型预测是5(误判为5的非5样本)
    # 含义:模型预测为"是5"的样本中,真正是5的比例(反映模型"不乱报"的能力)
    print(precision_score(y_train_5, y_train_pred))# 计算并打印召回率(Recall)
    # 参数说明同上,使用相同的真实标签和预测结果
    # 召回率公式:TP / (TP + FN)
    #   TP:同上(猜对的5)
    #   FN:真实是5但模型预测不是5(漏检的5样本)
    # 含义:所有真实是5的样本中,被模型成功识别出来的比例(反映模型"不漏检"的能力)
    print(recall_score(y_train_5, y_train_pred))

    2.1.4F1分数

    F1 分数(F1-Score)是一个综合指标,用来平衡合精确率(Precision) 和召回率(Recall) 的表现,避免单一指标的片面性。

    当精确率和召回率发生冲突时(比如一个高一个低),F1 分数能给出一个平衡的评价。比如:

    • 模型 A:精确率 90%,召回率 60%(太保守,漏检多)
    • 模型 B:精确率 60%,召回率 90%(太激进,误报多)
      F1 分数能客观比较哪个模型整体表现更好

    F1 分数是精确率和召回率的调和平均数

    • 取值范围:0~1(越接近 1,说明精确率和召回率都高,模型越均衡)
    • 特点:对精确率和召回率 “一视同仁”,如果其中一个很低,F1 分数也会很低。
    from sklearn.metrics import f1_score......#输出F1分数
    print(f1_score(y_train_5, y_train_pred))

    在很多实际场景中,我们需要根据具体需求优先关注精确率(Precision)或召回率(Recall),而不是一味追求两者平衡的 F1 分数。这取决于业务中 “误判” 和 “漏判” 的代价哪个更高。

    当 “把不是目标的样本错判为目标” 的代价很高时,需要优先保证精确率(即 “模型说‘是’的,尽量真的是”)。

    当 “把目标样本漏判为非目标” 的代价很高时,需要优先保证召回率(即 “所有真的是目标的,尽量都被模型找出来”)。

    • 垃圾邮件过滤:如果把正常邮件(非垃圾)误判为垃圾邮件(模型说 “是垃圾”),用户可能错过重要邮件,代价很高。此时更关注精确率 —— 宁可漏判一些垃圾邮件(让少量垃圾进 inbox),也不能误删正常邮件。
    • 肿瘤筛查(初步检测):如果用 AI 初步筛选 “疑似癌症”,如果把健康人误判为 “疑似癌症”(模型说 “是癌症”),会导致用户恐慌、不必要的进一步检查(耗时耗钱),此时需要高精确率 —— 模型标记的 “疑似” 必须尽可能真的有问题。
    • 推荐系统(如奢侈品推荐):如果给用户推荐不相关的商品(模型说 “用户会买” 但实际不会),会浪费流量、降低用户体验,此时需要精确率 —— 推荐的东西必须尽可能贴合用户需求。
    • 犯罪嫌疑人识别:如果漏掉了真正的嫌疑人(模型说 “不是嫌疑人” 但实际是),可能导致罪犯逃脱,代价极高。此时更关注召回率 —— 宁可错判一些无辜者(精确率低一点),也不能漏掉真凶。
    • 癌症确诊(最终诊断):如果把真正的癌症患者漏判为 “健康”(模型说 “不是癌症”),会延误治疗,甚至危及生命。此时必须优先保证召回率 —— 哪怕把很多健康人误判为 “疑似”(后续再人工复核),也不能漏掉一个真患者。
    • 网络攻击检测:如果漏掉了真正的攻击行为(模型说 “不是攻击” 但实际是),可能导致系统被入侵、数据泄露。此时需要高召回率 —— 宁可误报一些正常操作(后续人工排查),也不能放过任何潜在攻击。

    2.1.5 精度/召回率权衡

    SGDClassifier 识别 MNIST 中的 "5" 时,会给 784 个像素分别分配权重(比如 "5" 特有的像素位置权重为正),将图片像素值与对应权重相乘后求和,再和一个阈值比较 —— 超过阈值就判定为 "5",否则不是,就像根据关键特征打分后判断是否达标。阈值越高,模型判定为 “5” 的标准就越严格。比如原来 50 分就能算 “是 5”,现在阈值提到 80 分 —— 只有那些像素特征特别符合 “5” 的图片(总分超 80)才会被认成 5,其他哪怕有点像 5 但分数不够的,都会被归为 “不是 5”。

    结果就是:模型说 “是 5” 的图片里,真正是 5 的比例(精度)会提高,但很多真正的 5 可能因为分数不够被漏掉(召回率会下降)。

    # 导入从开放数据源获取数据集的工具
    from sklearn.datasets import fetch_openml
    # 导入数值计算库,用于处理数组和矩阵
    import numpy as np
    # 导入绘图库,用于可视化图像(本代码未实际绘图,可用于扩展)
    import matplotlib.pyplot as plt
    # 导入随机梯度下降分类器模型
    from sklearn.linear_model import SGDClassifier# 获取MNIST手写数字数据集
    # 'mnist_784'是数据集标识,包含70000张28x28像素的手写数字图片(展平为784维特征)
    # version=1指定数据集版本,data_home指定下载路径,as_frame=False返回numpy数组格式
    mnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)# 拆分特征数据和标签:
    # X:所有图片的像素数据(70000行×784列),每行是一张图片的784个像素值
    # y:所有图片对应的数字标签(70000个元素),以字符串形式存储(如'5'、'3')
    X, y = mnist.data, mnist.target  # 划分训练集和测试集:
    # 前60000个样本作为训练集(模型学习用),后10000个作为测试集(模型评估用)
    X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]# 设置随机种子,确保每次运行时随机操作的结果完全一致(可复现性)
    np.random.seed(42)
    # 生成0到59999的随机打乱的索引序列,用于打乱训练集顺序
    # 打乱数据能避免模型学习到样本顺序的无关规律,提高模型泛化能力
    shuffle_index = np.random.permutation(60000)# 使用随机索引重新排列训练集的特征和标签(保持样本与标签的对应关系)
    X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]# 构建二分类任务标签:判断样本是否为数字5
    # y_train_5是布尔数组,True表示对应样本是5,False表示不是5
    y_train_5 = (y_train == '5')
    # 同样处理测试集标签,用于后续模型评估
    y_test_5 = (y_test == '5')# 创建SGD分类器实例,random_state=42固定随机种子,保证训练过程可复现
    sgd_clf = SGDClassifier(random_state=42)# 用训练集训练模型,学习如何区分"是5"和"不是5"
    sgd_clf.fit(X_train, y_train_5)# 找到训练集中所有标签为'5'的样本索引
    # np.where返回满足条件的索引位置,[0]提取索引数组(因为返回格式是元组)
    five_indices = np.where(y_train == '5')[0]
    # 将索引转换为列表并打印,展示所有数字5在训练集中的位置
    print(five_indices.tolist())# 从训练集中选取索引为51868的样本(已知是数字5,从上面的索引列表中选取)
    some_digit = X_train[51868]# 计算该样本的决策分数:
    # 1. 先用reshape(1, -1)将一维数组转为二维数组(1个样本 × 784个特征),符合模型输入要求
    # 2. decision_function返回模型对该样本的评分(正数表示模型认为是5,负数表示认为不是5)
    y_scores = sgd_clf.decision_function(some_digit.reshape(1, -1))# 打印决策分数
    print(y_scores)

    在这个判断是否为 5 的二分类任务里,decision_function 输出分数越高,模型越笃定该样本是数字 5 ,分数高就代表模型对 “是 5” 这个判断的信心很足~。但是究竟分数多低就表示信心不足了呢?换言之,如何决定使用什么样的阈值呢?

    绘制精确率-召回率-阈值曲线
    
    # 导入需要的库
    # fetch_openml用于获取公开数据集,这里用于加载MNIST手写数字数据集
    from sklearn.datasets import fetch_openml
    # numpy用于数值计算和数组操作
    import numpy as np
    # matplotlib.pyplot用于数据可视化
    import matplotlib.pyplot as plt
    # SGDClassifier是随机梯度下降分类器,适用于大规模数据集
    from sklearn.linear_model import SGDClassifier
    # cross_val_score用于交叉验证评分,cross_val_predict用于交叉验证预测
    from sklearn.model_selection import cross_val_score, cross_val_predict
    # precision_recall_curve用于计算精确率、召回率和阈值的关系
    from sklearn.metrics import precision_recall_curve# 设置matplotlib支持中文显示
    # 解决中文显示为方框的问题
    plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
    # 解决负号显示异常的问题
    plt.rcParams["axes.unicode_minus"] = False# 获取MNIST数据集
    # 'mnist_784'是数据集名称,包含70000张28x28像素的手写数字图片
    # version=1指定数据集版本
    # data_home指定数据集下载后保存的本地路径
    # as_frame=False表示返回numpy数组格式而非DataFrame
    mnist = fetch_openml('mnist_784', version=1, data_home="D:\\deskTop\\机器学习实战\\第三章", as_frame=False)# 提取特征数据和标签
    # X包含所有图像的像素数据,形状为(70000, 784),每行代表一张图片的784个像素(28×28)
    # y包含所有图像对应的标签,形状为(70000,),存储的是0-9的字符串形式数字
    X, y = mnist.data, mnist.target# 划分训练集和测试集
    # MNIST数据集默认前60000个样本为训练集,后10000个为测试集
    X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]# 打乱训练集顺序
    # 生成0-59999的随机排列索引,用于打乱训练集,避免数据顺序对模型产生影响
    shuffle_index = np.random.permutation(60000)
    # 根据随机索引重新排列训练集的特征和标签
    X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]# 创建二分类标签:是否为数字5
    # 将原始标签转换为布尔值,True表示该样本是数字5,False表示不是
    # 这样就将问题转化为"是5"和"不是5"的二分类问题
    y_train_5 = (y_train == '5')  # 训练集的二分类标签
    y_test_5 = (y_test == '5')    # 测试集的二分类标签# 初始化并训练SGD分类器
    # SGDClassifier使用随机梯度下降算法,适合处理大型数据集
    # random_state=42设置随机种子,保证结果可复现
    sgd_clf = SGDClassifier(random_state=42)
    # 使用训练集训练模型,学习如何区分数字5和非5
    sgd_clf.fit(X_train, y_train_5)# 使用交叉验证获取决策分数
    # cross_val_predict进行交叉验证,并返回每个样本的决策分数
    # cv=3表示3折交叉验证,将训练集分成3份,轮流用2份训练1份验证
    # method="decision_function"表示返回决策函数的值(而非预测类别),用于后续计算阈值
    y_scores = cross_val_predict(sgd_clf, X_train, y_train_5,cv=3, method="decision_function"
    )# 输出部分决策分数
    # 决策分数表示模型认为样本属于正类(数字5)的置信度,值越高越可能是5
    # 只输出前10个,避免输出过多
    print("部分决策分数:", y_scores[:10])# 计算不同阈值下的精确率和召回率
    # precision_recall_curve函数会计算所有可能阈值对应的精确率和召回率
    # 精确率(Precision):预测为5的样本中真正是5的比例
    # 召回率(Recall):所有真正是5的样本中被成功预测为5的比例
    # 阈值(Threshold):用于判断样本属于正类或负类的临界值
    precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)# 定义绘制精确率-召回率-阈值曲线的函数
    def plot_precision_recall_curve(precisions, recalls, thresholds):# 创建一个10x6英寸的图形plt.figure(figsize=(10, 6))# 绘制精确率曲线,蓝色虚线,标签为"精确率(Precision)"plt.plot(thresholds, precisions[:-1], 'b--', label='精确率(Precision)')# 绘制召回率曲线,绿色实线,标签为"召回率(Recall)"plt.plot(thresholds, recalls[:-1], 'g-', label='召回率(Recall)')# 设置x轴标签plt.xlabel('阈值(Threshold)')# 设置y轴标签plt.ylabel('值')# 设置图表标题plt.title('精确率和召回率随阈值变化的曲线')# 显示图例,自动选择最佳位置plt.legend(loc='best')# 设置y轴范围,稍微超过1使图形更美观plt.ylim([0, 1.05])# 添加网格线,透明度0.3plt.grid(True, alpha=0.3)# 调用函数绘制曲线
    plot_precision_recall_curve(precisions, recalls, thresholds)
    # 显示图形
    plt.show()

    1. 先理解「阈值(Threshold)」:模型的 “5 判定标准”

    模型判断一个手写数字是不是 5 时,会先算一个 **“分数”**(由 SGDClassifier 的 decision_function 输出)。

    • 分数越高 → 模型越觉得这是 5;
    • 分数越低 → 模型越觉得这不是 5。

    阈值就是你定的一条 “及格线”:

    • 分数 ≥ 阈值 → 模型说 “是 5”;
    • 分数 < 阈值 → 模型说 “不是 5”。

    2. 绿色线(召回率):「真 5 被找全了吗?」

    召回率的意思是:所有真正是 5 的数字里,有多少被模型找出来了

    • 当阈值特别小(比如图最左边,阈值 =-200000):
      模型会 “不管三七二十一,觉得谁都像 5”,几乎所有真 5 都会被选中 → 召回率接近 1(绿色线贴到最顶)。
      但代价是:很多不是 5 的数字(比如 2、3、7)也会被误判成 5 → 精确率会特别低(蓝色线贴到最底)。

    • 当阈值特别大(比如图最右边,阈值 = 50000):
      模型会 “特别挑剔,只有超级像 5 的才敢说是 5”,几乎不会误判 → 精确率接近 1(蓝色线贴到最顶)。
      但代价是:很多真正的 5 会因为分数不够,被模型漏掉 → 召回率会特别低(绿色线掉到最底)。

    3. 蓝色线(精确率):「模型说 “是 5” 的,真的是 5 吗?」

    精确率的意思是:模型说 “是 5” 的数字里,有多少真的是 5

    • 当阈值小(左):
      模型 “乱认 5” → 精确率低(蓝色线低),但召回率高(绿色线高)。

    • 当阈值大(右):
      模型 “谨慎认 5” → 精确率高(蓝色线高),但召回率低(绿色线低)。

    4. 核心规律:「鱼和熊掌不可兼得」

    调整阈值时,召回率和精确率永远反向变化

    • 想找全所有真 5(高召回率)→ 必须接受误判(低精确率);
    • 想让模型说 “是 5” 的都真的是 5(高精确率)→ 必须接受漏掉一些真 5(低召回率)。

    5. 怎么选阈值?看你的需求!

    • 如果是「找错题」场景(比如老师想找出所有学生写的 5 分步骤批改):
      宁愿误判一些非 5 为 5,也不想漏掉真 5 → 选低阈值(往左边调),优先保召回率。

    • 如果是「自动分类归档」场景(比如把数字 5 单独放进一个文件夹,错放会打乱分类):
      宁愿漏掉一些真 5,也不想把非 5 放进来 → 选高阈值(往右边调),优先保精确率。

      http://www.lryc.cn/news/615704.html

      相关文章:

    • Deep Learning MNIST手写数字识别 Mac
    • 【Elasticsearch入门到落地】16、RestClient查询文档-快速入门
    • Lua的数组、迭代器、table、模块
    • 黑马SpringBoot+Elasticsearch作业2实战:商品搜索与竞价排名功能实现
    • sqli-labs-master/Less-51~Less-61
    • Lua语言变量、函数、运算符、循环
    • 【RocketMQ 生产者和消费者】- ConsumeMessageOrderlyService 顺序消费消息
    • 在windows安装colmap并在cmd调用
    • vue3前端项目cursor rule
    • 常用hook钩子函数
    • 海关 瑞数 失信企业 逆向 分析 后缀 rs
    • 从神经网络语言模型(NNLM)到Word2Vec:自然语言处理中的词向量学习
    • 【Html网页模板】炫酷科技风公司首页
    • Axure设计下的智慧社区数据可视化大屏:科技赋能社区管理
    • [0CTF 2016]piapiapia
    • PhotoDirector 安卓版:功能强大的照片编辑与美化应用
    • Dify集成 Echarts 实现智能数据报表集成与展示实战详解
    • 复杂项目即时通讯从android 5升级android x后遗症之解决 ANR: Input dispatching timed out 问题 -优雅草卓伊凡
    • 咪咕MGV3200-KLH_GK6323V100C_板号E503744_安卓9_短接强刷包-可救砖
    • WebAssembly技术详解:从浏览器到云原生的高性能革命
    • Flutter 与 Android NDK 集成实战:实现高性能原生功能
    • Vue3 组件化开发
    • Solana上Launchpad混战:新颖性应被重视
    • 一个“加锁无效“的诡异现象
    • BGP 笔记
    • Python 中的 Mixin
    • 第4章 程序段的反复执行2 while语句P128练习题(题及答案)
    • 【动态数据源】⭐️@DS注解实现项目中多数据源的配置
    • Datawhale AI夏令营第三期,多模态RAG方向 Task2
    • 深度学习入门Day8:生成模型革命——从GAN到扩散模型