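PredefinedHoldoutSplit: Utility for the holdout method compatible with scikit-learn
Split a dataset into a training subset and a validation subset based on user-specified validation indices.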
> `from mlxtend.evaluate import PredefinedHoldoutSplit`
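Overview
The `PredefinedHoldoutSplit` class serves as an alternative to scikit-learn's `KFold` class: it splits a dataset into a training subset and a validation subset based on validation indices specified by the user, without rotation. `PredefinedHoldoutSplit` can be passed as the `cv` argument of scikit-learn's `GridSearchCV` and similar tools.
For a random split, see the related `RandomHoldoutSplit` class.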
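To make "without rotation" concrete, the following is a minimal sketch of how such a splitter could work internally. The class name `SketchHoldoutSplit` is hypothetical and is not part of mlxtend; it only assumes NumPy and mirrors the `split` / `get_n_splits` interface described on this page.

```python
import numpy as np


class SketchHoldoutSplit:
    """Hypothetical stand-in that mimics a predefined holdout splitter."""

    def __init__(self, valid_indices):
        self.valid_indices = np.asarray(valid_indices)

    def get_n_splits(self, X=None, y=None, groups=None):
        # A holdout split has exactly one train/validation partition.
        return 1

    def split(self, X, y=None, groups=None):
        n_examples = X.shape[0]
        all_indices = np.arange(n_examples)
        valid_mask = np.zeros(n_examples, dtype=bool)
        valid_mask[self.valid_indices] = True
        # Yield once: every index not marked for validation is training data.
        yield all_indices[~valid_mask], all_indices[valid_mask]


X_demo = np.zeros((6, 2))  # six dummy examples with two features each
splits = list(SketchHoldoutSplit(valid_indices=[0, 1, 5]).split(X_demo))
train_idx, valid_idx = splits[0]

print(len(splits))  # 1
print(train_idx)    # [2 3 4]
print(valid_idx)    # [0 1 5]
```

Unlike `KFold`, where each fold takes a turn as the validation set, the generator here yields a single pair, so every model is scored on the same fixed validation examples.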
Example 1 -- Iterating Over a PredefinedHoldoutSplit
from mlxtend.evaluate import PredefinedHoldoutSplit
from mlxtend.data import iris_data
X, y = iris_data()
h_iter = PredefinedHoldoutSplit(valid_indices=[0, 1, 99])
cnt = 0
for train_ind, valid_ind in h_iter.split(X, y):
    cnt += 1
print(cnt)
1
print(train_ind[:5])
print(valid_ind[:5])
[2 3 4 5 6]
[ 0 1 99]
Example 2 -- PredefinedHoldoutSplit in GridSearch
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from mlxtend.evaluate import PredefinedHoldoutSplit
from mlxtend.data import iris_data
X, y = iris_data()
params = {'n_neighbors': [1, 2, 3, 4, 5]}
grid = GridSearchCV(KNeighborsClassifier(),
                    param_grid=params,
                    cv=PredefinedHoldoutSplit(valid_indices=[0, 1, 99]))
grid.fit(X, y)
GridSearchCV(cv=<mlxtend.evaluate.holdout.PredefinedHoldoutSplit object at 0x7fb300565610>,
             estimator=KNeighborsClassifier(),
             param_grid={'n_neighbors': [1, 2, 3, 4, 5]})
API
PredefinedHoldoutSplit(valid_indices)
Train/validation set splitter for sklearn's GridSearchCV etc.
Splits a dataset into a training set and a validation set using
user-specified validation set indices.
Parameters
- `valid_indices` : array-like, shape (num_examples,)
    Indices of the examples in the training set to be used for validation. All other indices in the training set are used to form a training subset for model fitting.
Examples
For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/evaluate/PredefinedHoldoutSplit/
Methods
get_n_splits(X=None, y=None, groups=None)
Returns the number of splitting iterations in the cross-validator
Parameters
- `X` : object
    Always ignored, exists for compatibility.
- `y` : object
    Always ignored, exists for compatibility.
- `groups` : object
    Always ignored, exists for compatibility.
Returns
- `n_splits` : 1
    Returns the number of splitting iterations in the cross-validator. Always returns 1.
split(X, y, groups=None)
Generate indices to split data into a training set and a validation set.
Parameters
- `X` : array-like, shape (num_examples, num_features)
    Training data, where num_examples is the number of examples and num_features is the number of features.
- `y` : array-like, shape (num_examples,)
    The target variable for supervised learning problems.
- `groups` : object
    Always ignored, exists for compatibility.
Yields
- `train_index` : ndarray
    The training set indices for that split.
- `valid_index` : ndarray
    The validation set indices for that split.