当前位置: 首页 > news >正文

学习笔记(29):训练集与测试集划分详解:train_test_split 函数深度解析

学习笔记(29):训练集与测试集划分详解:train_test_split 函数深度解析

一、为什么需要划分训练集和测试集?

在机器学习中,模型需要经历两个核心阶段:

  1. 训练阶段:用训练集数据学习特征与目标值的映射关系(如线性回归的权重)。
  2. 测试阶段:用测试集评估模型在未见过的数据上的表现,避免 “过拟合”(模型只记住训练数据的噪声,无法泛化到新数据)。

类比场景:学生通过 “练习题”(训练集)学习知识,再通过 “考试题”(测试集)检验真实水平。

二、train_test_split 函数的核心参数与逻辑
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42
)
1. 输入参数解析
  • X_scaled:特征矩阵(已标准化的面积、房龄等特征)。
  • y:目标变量(房价)。
  • test_size=0.2:测试集占总数据的比例(20%),也可设为整数(如 test_size=20 表示取 20 个样本)。
  • random_state=42:随机种子,确保每次划分结果一致(与 np.random.seed(42) 作用类似)。
2. 划分逻辑
  • 随机抽样:按 test_size 比例从原始数据中随机抽取样本作为测试集,剩余作为训练集。
  • 数据对齐:确保 X 和 y 的样本顺序一一对应(如第 i 个特征向量对应第 i 个房价标签)。
三、划分结果的维度与含义

假设原始数据有 100 个样本(n_samples=100):

  • 训练集:80 个样本(X_train.shape=(80, 2)y_train.shape=(80,)),用于模型学习。
  • 测试集:20 个样本(X_test.shape=(20, 2)y_test.shape=(20,)),用于评估模型泛化能力。
四、关键参数深度解析
1. test_size:平衡训练与测试的样本量
  • 取值建议
    • 小数据集(<1000 样本):常用 test_size=0.2~0.3(20%-30% 作为测试集)。
    • 大数据集(>10000 样本):可设 test_size=0.1 甚至更低(因少量样本已足够评估)。
  • 极端案例:若 test_size=1.0,则所有数据都是测试集,无训练集;若 test_size=0,则全是训练集。
2. random_state:确保可复现的 “随机” 划分
  • 作用:固定随机种子后,每次运行代码时,训练集和测试集的样本索引完全相同。
  • 示例对比
    • 不设置 random_state:每次划分结果不同,导致模型评估指标波动。
    • 设置 random_state=42:多次运行代码,划分结果一致,便于对比不同模型效果。
3. shuffle=True(默认参数):打乱数据顺序
  • 为什么需要打乱?
    若数据按顺序排列(如前 50 个是小户型,后 50 个是大户型),不打乱会导致训练集和测试集样本分布不均(如测试集全是大户型)。
  • 参数设置train_test_split 默认为 shuffle=True,即先打乱数据再划分;若数据已随机排列,可设 shuffle=False
五、进阶应用:分层抽样(Stratified Sampling)

当目标变量是分类变量(如二分类 “是否违约”)时,普通随机划分可能导致训练 / 测试集的类别比例失衡(如测试集全是 “违约” 样本)。此时需用 StratifiedShuffleSplit 实现分层抽样:

from sklearn.model_selection import StratifiedShuffleSplit# 4. 使用分层抽样(确保类别比例平衡)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_binary):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y_binary[train_idx], y_binary[test_idx]print("===== 分类模型结果 =====")
print(f"原始数据类别比例:{np.bincount(y_binary)/len(y_binary)}")
print(f"训练集类别比例:{np.bincount(y_train)/len(y_train)}")
print(f"测试集类别比例:{np.bincount(y_test)/len(y_test)}")
六、实战误区与注意事项
  1. 禁止在测试集上训练:测试集只能用于评估,若根据测试集结果调整模型参数(如调优正则化系数),本质上是 “偷看答案”,会导致评估结果过于乐观。
  2. 数据标准化的顺序
    • 正确流程:先划分训练测试集,再对训练集拟合标准化器(scaler.fit(X_train)),最后用训练集的标准化参数转换测试集(scaler.transform(X_test))。
    • 错误操作:对全量数据标准化后再划分,会导致测试集 “偷看到” 全量数据的统计特征,违反 “未知数据” 假设。
  3. 多轮划分与交叉验证:当数据量较小时,可使用 K 折交叉验证(如 10 折),将数据分成 10 份,每次用 9 份训练、1 份测试,重复 10 次取平均,减少单次划分的随机性误差。
七、总结:划分训练测试集的核心原则
  1. 独立性:测试集数据必须是模型未见过的,模拟真实应用场景。
  2. 代表性:训练集和测试集的样本分布应尽可能一致(如特征取值范围、类别比例)。
  3. 可复现性:通过设置随机种子,确保实验结果可重复验证。

通过合理划分训练集与测试集,你可以更准确地评估模型的实际能力,避免被 “过拟合” 的假象误导 —— 这是机器学习工程化中至关重要的一步!

二分类问题(房价是否高于中位数)-全代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report# 配置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 1. 生成模拟数据(假设房价与面积、房龄的关系)
np.random.seed(42)
n_samples = 100
# 面积(平方米),房龄(年)
X = np.random.rand(n_samples, 2) * 100
X[:, 0] = X[:, 0]  # 面积范围:0-100
X[:, 1] = X[:, 1]  # 房龄范围:0-100# 真实房价 = 5000*面积 + 1000*房龄 + 随机噪声(模拟真实场景)
y = 5000 * X[:, 0] + 1000 * X[:, 1] + np.random.randn(n_samples) * 10000# 2. 将连续的房价y转换为二分类标签(是否高于中位数)
threshold = np.median(y)
y_binary = (y > threshold).astype(int)  # 0=低于中位数,1=高于中位数# 3. 数据预处理:标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 4. 使用分层抽样(确保类别比例平衡)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_binary):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y_binary[train_idx], y_binary[test_idx]# 5. 训练逻辑回归模型(分类模型)
model = LogisticRegression()
model.fit(X_train, y_train)# 6. 预测并评估模型
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)# 7. 输出结果
print("===== 分类模型结果 =====")
print(f"原始数据类别比例:{np.bincount(y_binary)/len(y_binary)}")
print(f"训练集类别比例:{np.bincount(y_train)/len(y_train)}")
print(f"测试集类别比例:{np.bincount(y_test)/len(y_test)}")
print(f"准确率: {accuracy:.2f}")
print("混淆矩阵:")
print(conf_matrix)
print("分类报告:")
print(class_report)
print(f"模型系数: {model.coef_}")  # 面积和房龄的权重
print(f"模型截距: {model.intercept_}")

打印:

原始数据类别比例:[0.34 0.32 0.34]
训练集类别比例:[0.3375 0.325  0.3375]
测试集类别比例:[0.35 0.3  0.35]
均方误差: 101112597.45
决定系数R²: 1.00

代码解析:
核心步骤解析
  1. 数据准备与二分类转换

    • 生成与方案 1 相同的模拟数据(面积、房龄 → 房价)。
    • 将连续的房价y转换为二分类标签:
threshold = np.median(y)  # 使用中位数作为阈值
y_binary = (y > threshold).astype(int)  # 0=低于中位数,1=高于中位数
  1. 这样做的目的是将 “预测具体房价” 转化为 “判断房价高低”。

分层抽样(Stratified Sampling)

  • 使用StratifiedShuffleSplit确保训练集和测试集中高低房价的比例与原始数据一致:
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_idx, test_idx in sss.split(X_scaled, y_binary):X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]y_train, y_test = y_binary[train_idx], y_binary[test_idx]

打印类别比例验证分层效果:

原始数据类别比例:[0.5 0.5]
训练集类别比例:[0.5 0.5]
测试集类别比例:[0.5 0.5]

http://www.lryc.cn/news/578811.html

相关文章:

  • Servlet开发流程(包含IntelliJ IDEA项目添加Tomcat依赖的详细教程)
  • 玄机——某学校系统中挖矿病毒应急排查
  • 打造Docker Swarm集群服务编排部署指南:从入门到精通
  • 【公司环境下发布个人NPM包完整教程】
  • 网络协议概念与应用层
  • 解释LLM怎么预测下一个词语的
  • 图像二值化方法及 Python OpenCV 实现
  • 使用v-bind指令绑定属性
  • 【第三章:神经网络原理详解与Pytorch入门】01.神经网络算法理论详解与实践-(1)神经网络预备知识(线性代数、微积分、概率等)
  • 新能源汽车功率级测试自动化方案:从理论到实践的深度解析
  • 如何将文件从 iPhone 传输到 Android(新指南)
  • 网安-XSS-pikachu
  • MUX-VLAN基本概述
  • 【格与代数系统】格与哈斯图
  • 【分明集合】特征函数、关系与运算
  • 【HarmonyOS】鸿蒙使用仓颉编程入门
  • 【1.6 漫画数据库设计实战 - 从零开始设计高性能数据库】
  • UniApp完全支持快应用QUICKAPP-以及如何采用 Uni 模式开发发行快应用优雅草卓伊凡
  • 飞算智造JavaAI:智能编程革命——AI重构Java开发新范式
  • uniapp内置蓝牙打印
  • WPF中Style和Template异同
  • LEFE-Net:一种轴承故障诊断的轻量化高效特征提取网络
  • 设计模式(七)
  • 08跨域
  • 【环境配置】Neo4j Community Windows 安装教程
  • 7.可视化的docker界面——portainer
  • docker拉取镜像报错:Get https://registry-1.docker.io/v2/: net/http: request canceled
  • 基于SpringBoot + HTML 的网上书店系统
  • 大模型及agent开发5 OpenAI Assistant API 进阶应用
  • 电源芯片之DCDC初探索ING