交叉验证太重要了!

人工智能
交叉验证是机器学习和统计学中常用的一种技术,用于评估预测模型的性能和泛化能力,特别是在数据有限或评估模型对新的未见数据的泛化能力时,交叉验证非常有价值。

首先需要搞明白,为什么需要交叉验证?

交叉验证是机器学习和统计学中常用的一种技术,用于评估预测模型的性能和泛化能力,特别是在数据有限或评估模型对新的未见数据的泛化能力时,交叉验证非常有价值。

那么具体在什么情况下会使用交叉验证呢?

  • 模型性能评估:交叉验证有助于估计模型在未见数据上的表现。通过在多个数据子集上训练和评估模型,交叉验证提供了比单一训练-测试分割更稳健的模型性能估计。
  • 数据效率:在数据有限的情况下,交叉验证充分利用了所有可用样本,通过同时使用所有数据进行训练和评估,提供了对模型性能更可靠的评估。
  • 超参数调优:交叉验证通常用于选择模型的最佳超参数。通过在不同数据子集上使用不同的超参数设置来评估模型的性能,可以确定在整体性能上表现最好的超参数值。
  • 检测过拟合:交叉验证有助于检测模型是否对训练数据过拟合。如果模型在训练集上的表现明显优于验证集,可能表明存在过拟合的情况,需要进行调整,如正则化或选择更简单的模型。
  • 泛化能力评估:交叉验证提供了对模型对未见数据的泛化能力的评估。通过在多个数据分割上评估模型,它有助于评估模型捕捉数据中的潜在模式的能力,而不依赖于随机性或特定的训练-测试分割。

交叉验证的大致思想可如图5折交叉所示,在每次迭代中,新模型在四个子数据集上训练,并在最后一个保留的子数据集上进行测试,确保所有数据得到利用。通过平均分数及标准差等指标,提供了对模型性能的真实度量

一切还得从K折交叉开始。

KFold

K折交叉在Sklearn中已经集成,此处以7折为例:

from sklearn.datasets import make_regression
from sklearn.model_selection import KFold

x, y = make_regression(n_samples=100)

# Init the splitter
cross_validation = KFold(n_splits=7)

还有一个常用操作是在执行拆分前进行Shuffle,通过破坏样本的原始顺序进一步最小化了过度拟合的风险:

cross_validation = KFold(n_splits=7, shuffle=True)

这样,一个简单的k折交叉验证就实现了,记得看源码看源码看源码!!

StratifiedKFold

StratifiedKFold是专门为分类问题而设计。

在有的分类问题中,即使将数据分成多个集合,目标分布也应该保持不变。比如大多数情况下,具有30到70类别比例的二元目标在训练集和测试集中仍应保持相同的比例,在普通的KFold中,这个规则被打破了,因为在拆分之前对数据进行shuffle时,类别比例将无法保持。

为了解决这个问题,在Sklearn中使用了另一个专门用于分类的拆分器类——StratifiedKFold:

from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

x, y = make_classification(n_samples=100, n_classes=2)

cross_validation = StratifiedKFold(n_splits=7, shuffle=True, random_state=1121218)

虽然看起来与KFold相似,但现在类别比例在所有的split和迭代中都维持一致。

ShuffleSplit

有的时候只是多次重复进行训练/测试集拆分过程,也是和交叉验证很像的一种方式。

从逻辑上讲,使用不同的随机种子生成多个训练/测试集应该在足够多的迭代中类似于一个稳健的交叉验证过程。

Sklearn中也有提供接口:

from sklearn.model_selection import ShuffleSplit

cross_validation = ShuffleSplit(n_splits=7, train_size=0.75, test_size=0.25)

TimeSeriesSplit

当数据集为时间序列时,不能使用传统的交叉验证,这将完全打乱顺序,为了解决这个问题,参考Sklearn提供了另一个拆分器——TimeSeriesSplit,

from sklearn.model_selection import TimeSeriesSplit

cross_validation = TimeSeriesSplit(n_splits=7)

如图,验证集始终位于训练集的索引之后。由于索引是日期,不会意外地在未来的日期上训练时间序列模型并对之前的日期进行预测。

非独立同分布(non-IID)数据的交叉验证

前面所述方法均在处理独立同分布数据集,也就是说生成数据的过程不会受到其他样本的影响。

然而,有些情况下,数据并不满足IID的条件,即一些样本组之间存在依赖关系,Kaggle上的竞赛就有出现,如Google Brain Ventilator Pressure,该数据记录了人工肺在数千个呼吸过程中(吸入和呼出)的气压值,并且对每次呼吸的每个时刻进行了记录,每个呼吸过程大约有80行数据,这些行之间是相互关联的,在这种情况下,传统的交叉验证无法工作,因为拆分可能会“刚好发生在一个呼吸过程的中间”。

可以理解为需要对这些数据进行“分组”,因为组内数据是有关联的,比如当从多个患者收集医疗数据时,每个患者都有多个样本,而这些数据很可能会受到患者个体差异的影响,所以也需要分组。

往往我们希望在一个特定组别上训练的模型是否能够很好地泛化到其他未见过的组别,所以在进行交差验证时给这些组别数据打上“tag”,告诉他们如何区分别瞎拆。

在Sklearn中提供了若干接口处理这些情况:

  • GroupKFold
  • StratifiedGroupKFold
  • LeaveOneGroupOut
  • LeavePGroupsOut
  • GroupShuffleSplit

强烈建议搞清楚交叉验证的思想,以及如何实现,搭配看Sklearn源码是一个肥肠不错的方式。此外,需要对自己的数据集有着清晰的定义,数据预处理真的很重要。

责任编辑:赵宁宁 来源: 啥都会一点的研究生
相关推荐

2022-03-23 10:09:27

CIOTarget公司首席

2014-03-17 09:31:36

Linux桌面

2022-12-15 16:53:55

2022-03-28 20:59:17

交叉验证模型

2022-07-11 08:37:41

nacosSLB长连接

2020-07-15 07:45:51

Python开发工具

2016-01-06 09:49:58

云计算服务器

2017-06-26 10:43:22

互联网

2024-10-30 14:00:01

2024-10-30 08:23:07

2022-03-04 15:19:59

Spring BooJavaVert.x

2021-08-30 14:23:41

身份验证隐私管理网络安全

2022-08-14 16:04:15

机器学习数据集算法

2013-12-18 14:17:00

操作系统边缘化移动设备

2021-02-15 15:20:08

架构程序员软件

2010-08-30 10:48:40

职场

2021-08-03 09:33:55

HTTP网络协议TCP

2023-03-03 09:31:52

容器技术

2018-03-28 14:04:31

2017-10-18 16:08:15

可视化交叉验证代码
点赞
收藏

51CTO技术栈公众号