当前位置: 首页 > news >正文

StratifiedKFold解释和代码实现

StratifiedKFold解释和代码实现

文章目录

  • 一、StratifiedKFold是什么?
  • 二、 实验数据设置
    • 2.1 实验数据生成代码
    • 2.2 代码结果
  • 三、实验代码
    • 3.1 实验代码
    • 3.2 实验结果
    • 3.3 结果解释
    • 3.4 数据打乱对这种交叉验证的影响。
  • 四、总结


一、StratifiedKFold是什么?

0,1,2,3:每一行表示测试集和训练集的划分的一种方式。
class:表示类别的个数(下图显示的是3类),有些交叉验证根据类别的比例划分测试集和训练集(例三)。
group:表示从不同的组采集到的样本,颜色的个数表示组的个数(有些时候我们关注在一组特定组上训练的模型是否能很好地泛化到看不见的组)。举个例子(解释“组”的意思):我们有10个人,我们想要希望训练集上所用的数据来自(1,2,3,4,5,6,7,8),测试集上的数据来自(9,10),也就是说我们不希望测试集上的数据和训练集上的数据来自同一个人(如果来自同一个人的话,训练集上的信息泄漏到测试集上了,模型的泛化性能会降低,测试结果会偏好)。
在这里插入图片描述

二、 实验数据设置

2.1 实验数据生成代码

X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("数据:", end=" ")
for l in X:print(l, end=' ')
print("")
print("标签:", y)

2.2 代码结果

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
标签: [0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

数据个数、标签个数:30个
类别个数:3个(分别是0,1,2,比例是0.1:0.3:0.6和class每类对应),StratifiedKFold
组别(group):由于StratifiedKFold交叉验证结果和group无关,所以这里不再设置。

三、实验代码

3.1 实验代码

代码如下:

from sklearn.model_selection import StratifiedKFold
import numpy as np
# X, y = np.ones((30, 1)), np.hstack(([0] * 20, [1] * 10))
# print(np.arange(0,30).reshape((30,1)))
X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("数据:", end=" ")
for l in X:print(l, end=' ')
print("")
print("标签:", y)
skf = StratifiedKFold(n_splits=3)
for i,(train, test) in enumerate(skf.split(X, y)):print("=================StratifiedKFold 第%d折叠 ===================="% (i+1))print('train -  {}'.format(np.bincount(y[train])))print("  训练集索引:%s" % train)print("  训练集标签:", y[train])print("  训练集数据:", end=" ")for l in X[train]:print(l, end=' ')print("")# print("  训练集数据:", X[train])print("test  -  {}".format(np.bincount(y[test])))print("  测试集索引:%s" % test)print("  测试集标签:", y[test])print("  测试集数据:", end=" ")for l in X[test]:print(l, end=' ')print("")# print("  测试集数据:", X[test])print("=============================================================")

3.2 实验结果

结果如下:

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
标签: [0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
=================StratifiedKFold 第1折叠 ====================
train -  [ 2  6 12]训练集索引:[ 1  2  6  7  8  9 10 11 18 19 20 21 22 23 24 25 26 27 28 29]训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]训练集数据: [2 3] [4 5] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]测试集索引:[ 0  3  4  5 12 13 14 15 16 17]测试集标签: [0 1 1 1 2 2 2 2 2 2]测试集数据: [0 1] [6 7] [8 9] [10 11] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] 
=============================================================
=================StratifiedKFold 第2折叠 ====================
train -  [ 2  6 12]训练集索引:[ 0  2  3  4  5  9 10 11 12 13 14 15 16 17 24 25 26 27 28 29]训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]训练集数据: [0 1] [4 5] [6 7] [8 9] [10 11] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]测试集索引:[ 1  6  7  8 18 19 20 21 22 23]测试集标签: [0 1 1 1 2 2 2 2 2 2]测试集数据: [2 3] [12 13] [14 15] [16 17] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] 
=============================================================
=================StratifiedKFold 第3折叠 ====================
train -  [ 2  6 12]训练集索引:[ 0  1  3  4  5  6  7  8 12 13 14 15 16 17 18 19 20 21 22 23]训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]训练集数据: [0 1] [2 3] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] 
test  -  [1 3 6]测试集索引:[ 2  9 10 11 24 25 26 27 28 29]测试集标签: [0 1 1 1 2 2 2 2 2 2]测试集数据: [4 5] [18 19] [20 21] [22 23] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
=============================================================进程已结束,退出代码 0

3.3 结果解释

可以看到测试集和训练集划分是根据折叠数和标签的比例。例如:这里的折叠数是3,标签的比例是1:3:6,所以在第一折叠处测试集标签0的个数是1/3(折叠数)*0.1(标签比例)*30(样本数)=1个。剩余的分析同理。

=================StratifiedKFold 第1折叠 ====================
train -  [ 2  6 12]训练集索引:[ 1  2  6  7  8  9 10 11 18 19 20 21 22 23 24 25 26 27 28 29]训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]训练集数据: [2 3] [4 5] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]测试集索引:[ 0  3  4  5 12 13 14 15 16 17]测试集标签: [0 1 1 1 2 2 2 2 2 2]测试集数据: [0 1] [6 7] [8 9] [10 11] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] 
=============================================================

3.4 数据打乱对这种交叉验证的影响。

X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))

改为下面的代码

arr = np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("原始标签:", arr)
# 使用np.random.shuffle函数将数组打乱
np.random.shuffle(arr)
X, y = np.arange(0,60).reshape((30,2)), arr

可以看出划分和标签的先后顺序有一定的关系。

四、总结

StratifiedKFold:考虑了标签(class),但没考虑组(group)的影响。

http://www.lryc.cn/news/271305.html

相关文章:

  • 四十八----react实战
  • 三步实现Java的SM2前端加密后端解密
  • 1分钟带你了解golang(go语言)
  • CSS-4
  • Python为何适合开发AI项目?
  • 总结心得:各设计模式使用场景
  • 详解Vue3中的事件监听方式
  • Unity关于easySave2 easySave3保存数据的操作;包含EasySave3运行报错的解决
  • 2022年全球软件质量效能大会(QECon上海站)-核心PPT资料下载
  • 【python报错】UserWarning: train_labels has been renamed targets
  • 算法专题四:前缀和
  • STM32学习笔记十五:WS2812制作像素游戏屏-飞行射击游戏(5)探索动画之帧动画
  • 期末复习(程序设计)
  • html-css-js移动端导航栏底部固定+i18n国际化全局
  • Ubuntu Linux 入门指南:面向初学者
  • 常见算法面试题目
  • PiflowX组件-JDBCWrite
  • 算法导论复习题目
  • HTTPS协议详解
  • 菜鸟学习vue3笔记-vue3 router回顾
  • Mybatis枚举类型处理和类型处理器
  • 2023 NCTF writeup
  • golang的大杀器协程goroutine
  • [Angular] 笔记 9:list/detail 页面以及@Output
  • Linux学习笔记(一)
  • Python 爬虫 教程
  • uniapp原生插件 - android原生插件打包流程 ( 避坑指南一)
  • 搭建maven私服
  • EST-100身份证社保卡签批屏按捺终端PC版web版本http协议接口文档,支持web网页开发对接使用
  • 基于SpringBoot的毕业论文管理系统