交叉验证
交叉验证¶
有关交叉验证的更全面讨论,请参阅 scikit-learn 交叉验证文档。本文档仅描述了为支持 Dask 数组所做的扩展。
拆分一个或多个 Dask 数组的最简单方法是使用 dask_ml.model_selection.train_test_split()
:
In [1]: import dask.array as da
In [2]: from dask_ml.datasets import make_regression
In [3]: from dask_ml.model_selection import train_test_split
In [4]: X, y = make_regression(n_samples=125, n_features=4, random_state=0, chunks=50)
In [5]: X
Out[5]: dask.array<normal, shape=(125, 4), dtype=float64, chunksize=(50, 4), chunktype=numpy.ndarray>
拆分 Dask 数组的接口与 scikit-learn 的版本相同。
In [6]: X_train, X_test, y_train, y_test = train_test_split(X, y)
In [7]: X_train # A dask Array
Out[7]: dask.array<concatenate, shape=(112, 4), dtype=float64, chunksize=(45, 4), chunktype=numpy.ndarray>
In [8]: X_train.compute()[:3]
Out[8]:
array([[ 1.32640188, -0.66775439, 0.61306745, 0.27349377],
[-1.05879148, -0.89722898, 0.96089406, -0.97585992],
[ 0.00550963, 0.02080571, -0.4763186 , -1.56842247]])
虽然可以将 dask 数组传递给 sklearn.model_selection.train_test_split()
,但我们出于性能原因推荐使用 Dask 版本:Dask 版本更快,原因有二:
首先,Dask 版本按块进行洗牌。在分布式环境中,块之间的洗牌可能需要在机器之间发送大量数据,这可能会很慢。然而,如果你的数据有很强的模式,你可能需要执行完全洗牌。
其次,Dask 版本避免了分配用于存储切片索引的大型中间 NumPy 数组。对于非常大的数据集,创建和传输 np.arange(n_samples)
可能会很昂贵。