PredefinedSplit#

class sklearn.model_selection.PredefinedSplit(test_fold)#

预定义的分割交叉验证器。

提供训练/测试索引来将数据分割成训练/测试集,使用由用户通过 test_fold 参数指定的预定义方案。

更多信息请参阅 用户指南

Added in version 0.16.

Parameters:
test_foldarray-like of shape (n_samples,)

条目 test_fold[i] 表示样本 i 所属的测试集的索引。可以通过将 test_fold[i] 设置为 -1 来排除样本 i 从任何测试集(即在每个训练集中包含样本 i )。

Examples

>>> import numpy as np
>>> from sklearn.model_selection import PredefinedSplit
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> test_fold = [0, 1, -1, 1]
>>> ps = PredefinedSplit(test_fold)
>>> ps.get_n_splits()
2
>>> print(ps)
PredefinedSplit(test_fold=array([ 0,  1, -1,  1]))
>>> for i, (train_index, test_index) in enumerate(ps.split()):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"  Test:  index={test_index}")
Fold 0:
  Train: index=[1 2 3]
  Test:  index=[0]
Fold 1:
  Train: index=[0 2]
  Test:  index=[1 3]
get_metadata_routing()#

获取此对象的元数据路由。

请查看 用户指南 以了解路由机制的工作原理。

Returns:
routingMetadataRequest

MetadataRequest 封装的 路由信息。

get_n_splits(X=None, y=None, groups=None)#

返回交叉验证器中的分割迭代次数。

Parameters:
Xobject

总是被忽略,存在是为了兼容性。

yobject

总是被忽略,存在是为了兼容性。

groupsobject

总是被忽略,存在是为了兼容性。

Returns:
n_splitsint

返回交叉验证器中的分割迭代次数。

split(X=None, y=None, groups=None)#

生成索引以将数据拆分为训练集和测试集。

Parameters:
X对象

总是被忽略,存在以确保兼容性。

y对象

总是被忽略,存在以确保兼容性。

groups对象

总是被忽略,存在以确保兼容性。

Yields:
trainndarray

该拆分的训练集索引。

testndarray

该拆分的测试集索引。