.. currentmodule:: sklearn.model_selection .. _grid_search: =========================================== 调整估计器的超参数 =========================================== 超参数是那些不在估计器内部直接学习的参数。 在 scikit-learn 中,它们作为参数传递给估计器类的构造函数。典型的例子包括支持向量分类器中的 ``C`` 、 ``kernel`` 和 ``gamma`` ,Lasso 中的 ``alpha`` 等。 可以并且建议搜索超参数空间以获得最佳的 :ref:`交叉验证 ` 分数。 任何在构造估计器时提供的参数都可以通过这种方式进行优化。具体来说,要找到给定估计器的所有参数的名称和当前值,可以使用:: estimator.get_params() 一个搜索包括: - 一个估计器(回归器或分类器,例如 ``sklearn.svm.SVC()`` ); - 一个参数空间; - 一种搜索或采样候选参数的方法; - 一个交叉验证方案;以及 - 一个 :ref:`评分函数 ` 。 scikit-learn 提供了两种通用的参数搜索方法:对于给定值,:class:`GridSearchCV` 穷举考虑所有参数组合,而 :class:`RandomizedSearchCV` 可以从具有指定分布的参数空间中采样给定数量的候选参数。这两种工具都有连续减半的对应版本 :class:`HalvingGridSearchCV` 和 :class:`HalvingRandomSearchCV` ,它们可以更快地找到一个好的参数组合。 在描述这些工具之后,我们详细说明了适用于这些方法的 :ref:`最佳实践 ` 。一些模型允许专门的、高效的参数搜索策略,这些策略在 :ref:`alternative_cv` 中概述。 请注意,这些参数中的一小部分可能对模型的预测性能或计算性能有重大影响,而其他参数 可以保留其默认值。建议阅读估计器类的文档字符串,以更精细地理解其预期行为,可能通过阅读所附的文献参考。 穷举网格搜索 ============== :class:`GridSearchCV` 提供的网格搜索通过 ``param_grid`` 参数指定的参数值网格中穷举生成候选。例如,以下 ``param_grid`` :: param_grid = [ {'C': [1, 10, 100, 1000], 'kernel': ['linear']}, {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']}, ] 指定应探索两个网格:一个具有线性核和 C 值在 [1, 10, 100, 1000] 中,第二个具有 RBF 核,以及 C 值在 [1, 10, 100, 1000] 和 gamma 值在 [0.001, 0.0001] 中的交叉积。 :class:`GridSearchCV` 实例实现了通常的估计器 API:当在数据集上“拟合”时,所有可能的参数值组合都会被评估,并保留最佳组合。 .. currentmodule:: sklearn.model_selection .. rubric:: 示例 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py` 以查看在 digits 数据集上进行网格搜索计算的示例。 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_text_feature_extraction.py` 以查看网格搜索结合文本文档特征提取器(n-gram 计数向量化器和 TF-IDF 转换器)与分类器(此处为使用 SGD 训练的线性 SVM,带有弹性网或 L2 惩罚)的示例,使用 :class:`~sklearn.pipeline.Pipeline` 实例。 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_nested_cross_validation_iris.py` 以查看在 iris 数据集上进行交叉验证循环中的网格搜索的示例。这是使用网格搜索评估模型性能的最佳实践。 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation.py` 以了解如何使用 :class:`GridSearchCV` 同时评估多个指标的示例。 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_refit_callable.py` 以了解在 :class:`GridSearchCV` 中使用 ``refit=callable`` 接口的示例。该示例展示了此接口如何在识别“最佳”估计器时增加一定的灵活性。此接口也可用于多指标评估。 - 参见 :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_stats.py` 以了解如何对 :class:`GridSearchCV` 的输出进行统计比较的示例。 .. _randomized_parameter_search: 随机参数优化 ============ 虽然使用参数设置的网格是目前最广泛使用的参数优化方法,但其他搜索方法具有更有利的特性。 :class:`RandomizedSearchCV` 实现了参数的随机搜索,其中每个设置从可能参数值的分布中采样。这比穷举搜索有两个主要好处: * 可以选择独立于参数数量和可能值的预算。 * 添加不影响性能的参数不会降低效率。 指定如何采样参数是通过一个字典完成的,与指定 :class:`GridSearchCV` 的参数非常相似。此外,使用 ``n_iter`` 参数指定计算预算,即采样候选数或采样迭代次数。对于每个参数,可以指定可能值的分布或离散选择的列表(将均匀采样):: {'C': scipy.stats.expon(scale=100), 'gamma': scipy.stats.expon(scale=.1), 'kernel': ['rbf'], 'class_weight':['balanced', None]} 此示例使用了 ``scipy.stats`` 模块,该模块包含许多有用的 用于采样参数的分布,例如 ``expon`` 、 ``gamma`` 、 ``uniform`` 、 ``loguniform`` 或 ``randint`` 。 原则上,任何提供 ``rvs`` (随机变量样本)方法来采样值的函数都可以传递。对 ``rvs`` 函数的调用应在连续调用中提供可能参数值的独立随机样本。 .. warning:: 在 scipy 0.16 之前的版本中, ``scipy.stats`` 中的分布不允许指定随机状态。相反,它们使用全局的 numpy 随机状态,可以通过 ``np.random.seed`` 进行种子设定或使用 ``np.random.set_state`` 进行设置。然而,从 scikit-learn 0.18 开始,如果 scipy >= 0.16 也可用,:mod:`sklearn.model_selection` 模块会设置用户提供的随机状态。 对于连续参数,如上述的 ``C`` ,指定一个连续分布以充分利用随机化是很重要的。这样,增加 ``n_iter`` 将始终导致更精细的搜索。 连续的对数均匀随机变量是日志间隔参数的连续版本。例如,要指定与上述 ``C`` 等效的参数,可以使用 ``loguniform(1, 100)`` 而不是 ``[1, 10, 100]`` 。 以上述网格搜索为例,我们可以指定一个在 ``1e0`` 和 ``1e3`` 之间对数均匀分布的连续随机变量:: from sklearn.utils.fixes import loguniform {'C': loguniform(1e0, 1e3), 'gamma': loguniform(1e-4, 1e-3), 'kernel': ['rbf'], 'class_weight':['balanced', None]} .. rubric:: 示例 * :ref:`sphx_glr_auto_examples_model_selection_plot_randomized_search.py` 比较了随机搜索和网格搜索的使用和效率。 .. rubric:: 参考文献 * Bergstra, J. 和 Bengio, Y., 随机搜索用于超参数优化, 机器学习研究杂志 (2012) .. _successive_halving_user_guide: 使用连续减半搜索最优参数 ======================================================== Scikit-learn 还提供了 :class:`HalvingGridSearchCV` 和 :class:`HalvingRandomSearchCV` 估计器,这些估计器可以用于使用连续减半 [1]_ [2]_ 来搜索参数空间。连续减半(SH)类似于候选参数组合之间的锦标赛。SH 是一个迭代选择过程,其中所有候选参数组合在第一次迭代中使用少量资源进行评估。只有部分候选参数组合被选中进入下一次迭代,这些组合将被分配更多资源。对于参数调整,资源通常是训练样本的数量,但也可以是任意数值参数,例如随机森林中的 `n_estimators` 。 如下图所示,只有一部分候选参数组合“存活”到最后一轮迭代。这些是在所有迭代中始终排名靠前的候选参数组合。每轮迭代为每个候选参数组合分配越来越多的资源,这里指的是样本数量。 .. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_successive_halving_iterations_001.png :target: ../auto_examples/model_selection/plot_successive_halving_iterations.html :align: center 我们在这里简要描述主要参数,但每个参数及其交互作用在下面的章节中有更详细的描述。 ``factor`` (> 1)参数控制资源增长的速率以及候选参数组合数量减少的速率。在每次迭代中,每个候选参数组合的资源数量乘以 ``factor`` ,而候选参数组合的数量除以相同的因子。与 ``resource`` 和 ``min_resources`` 一起, ``factor`` 是我们实现中控制搜索的最重要参数,尽管通常值为 3 效果良好。 ``factor`` 有效地控制了连续减半过程中的迭代次数。 :class:`HalvingGridSearchCV` 和候选数量(默认情况下)以及 :class:`HalvingRandomSearchCV` 中的迭代次数。如果可用资源较少,也可以使用 ``aggressive_elimination=True`` 。通过调整 ``min_resources`` 参数可以获得更多控制。 这些估计器仍然是 **实验性的**:它们的预测和 API 可能会在没有任何弃用周期的情况下发生变化。要使用它们,您需要显式导入 ``enable_halving_search_cv`` :: >>> # 显式要求此实验性功能 >>> from sklearn.experimental import enable_halving_search_cv # noqa >>> # 现在您可以正常从 model_selection 导入 >>> from sklearn.model_selection import HalvingGridSearchCV >>> from sklearn.model_selection import HalvingRandomSearchCV .. rubric:: 示例 * :ref:`sphx_glr_auto_examples_model_selection_plot_successive_halving_heatmap.py` * :ref:`sphx_glr_auto_examples_model_selection_plot_successive_halving_iterations.py` 选择 ``min_resources`` 和候选数量 ------------------------------------------------------- 除了 ``factor`` 之外,影响连续减半搜索行为的两个主要参数是 ``min_resources`` 参数和评估的候选数量(或参数组合)。 ``min_resources`` 是在每个候选的第一次迭代中分配的资源量。候选数量在 :class:`HalvingRandomSearchCV` 中直接指定,并由 :class:`HalvingGridSearchCV` 的 ``param_grid`` 参数确定。 考虑一个资源是样本数量的情况,我们有 1000 个样本。理论上,使用 ``min_resources=10`` 和 ``factor=2`` ,我们最多可以运行 7 次迭代,样本数量如下: ``[10, 20, 40, 80, 160, 320, 640]`` 。 但根据候选数量,我们可能运行少于 7 次迭代:如果我们从 **较少** 的候选数量开始,最后 迭代可能使用少于640个样本,这意味着没有充分利用所有可用资源(样本)。例如,如果我们从5个候选开始,我们只需要2次迭代:第一次迭代有5个候选,然后第二次迭代有 `5 // 2 = 2` 个候选,之后我们就知道哪个候选表现最好(因此不需要第三次迭代)。我们最多只会使用20个样本,这是浪费,因为我们有1000个样本可供使用。另一方面,如果我们从**大量**候选开始,我们可能会在最后一次迭代中得到很多候选,这可能并不总是理想的:这意味着许多候选将使用全部资源运行,基本上将过程简化为标准搜索。 在:class:`HalvingRandomSearchCV` 的情况下,候选数量默认设置为使得最后一次迭代尽可能多地使用可用资源。对于:class:`HalvingGridSearchCV` ,候选数量由 `param_grid` 参数决定。更改 ``min_resources`` 的值将影响可能的迭代次数,因此也会影响理想的候选数量。 选择 ``min_resources`` 时的另一个考虑因素是,用少量资源区分好候选和坏候选是否容易。例如,如果你需要大量样本来区分好参数和坏参数,建议使用较高的 ``min_resources`` 。另一方面,如果即使样本量很小也能清楚地区分,那么较小的 ``min_resources`` 可能更可取,因为它会加快计算速度。 请注意,在上面的例子中,最后一次迭代没有使用最大数量的可用资源:有1000个样本可用,但最多只使用了640个。默认情况下,:class:`HalvingRandomSearchCV` 和:class:`HalvingGridSearchCV` 都试图在最后一次迭代中尽可能多地使用资源。 最后一次迭代,其约束条件是资源量必须是 `min_resources` 和 `factor` 的倍数(这一约束将在下一节中变得清晰)。:class:`HalvingRandomSearchCV` 通过采样适当数量的候选者来实现这一点,而 :class:`HalvingGridSearchCV` 则通过适当设置 `min_resources` 来实现这一点。详情请参阅 :ref:`exhausting_the_resources` 。 .. _amount_of_resource_and_number_of_candidates: 每次迭代的资源量和候选者数量 -------------------------------------------------- 在任意迭代 `i` 中,每个候选者被分配一定量的资源,我们将其表示为 `n_resources_i` 。这一数量由参数 ``factor`` 和 ``min_resources`` 控制,如下所示( `factor` 严格大于 1):: n_resources_i = factor**i * min_resources, 或者等价地:: n_resources_{i+1} = n_resources_i * factor 其中 ``min_resources == n_resources_0`` 是第一次迭代中使用的资源量。 ``factor`` 还定义了将选择用于下一次迭代的候选者的比例:: n_candidates_i = n_candidates // (factor ** i) 或者等价地:: n_candidates_0 = n_candidates n_candidates_{i+1} = n_candidates_i // factor 因此在第一次迭代中,我们使用 ``min_resources`` 资源 ``n_candidates`` 次。在第二次迭代中,我们使用 ``min_resources * factor`` 资源 ``n_candidates // factor`` 次。第三次再次增加每个候选者的资源量并减少候选者数量。这一过程在达到每个候选者的最大资源量或识别出最佳候选者时停止。最佳候选者在评估 `factor` 或更少候选者的迭代中被识别(解释如下)。 以下是一个示例,其中 ``min_resources=3`` 和 ``factor=2`` ,从 70 个候选者开始: +-----------------------+-----------------------+ | 迭代次数 | 资源量 | +=======================+=======================+ | 1 | 3 | +-----------------------+-----------------------+ | 2 | 6 | +-----------------------+-----------------------+ | 3 | 12 | +-----------------------+-----------------------+ | 4 | 24 | +-----------------------+-----------------------+ | 5 | 48 | +-----------------------+-----------------------+ | 6 | 96 | +-----------------------+-----------------------+ | 7 | 192 | +-----------------------+-----------------------+ | 8 | 384 | +-----------------------+-----------------------+ | 9 | 768 | +-----------------------+-----------------------+ | 10 | 1536 | +-----------------------+-----------------------+ | 11 | 3072 | +-----------------------+-----------------------+ | 12 | 6144 | +-----------------------+-----------------------+ | 13 | 12288 | +-----------------------+-----------------------+ | 14 | 24576 | +-----------------------+-----------------------+ | 15 | 49152 | +-----------------------+-----------------------+ | 16 | 98304 | +-----------------------+-----------------------+ | 17 | 196608 | +-----------------------+-----------------------+ | 18 | 393216 | +-----------------------+-----------------------+ | 19 | 786432 | +-----------------------+-----------------------+ | 20 | 1572864 | +-----------------------+-----------------------+ | 21 | 3145728 | +-----------------------+-----------------------+ | 22 | 6291456 | +-----------------------+-----------------------+ | 23 | 12582912 | +-----------------------+-----------------------+ | 24 | 25165824 | +-----------------------+-----------------------+ | 25 | 50331648 | +-----------------------+-----------------------+ | 26 | 100663296 | +-----------------------+-----------------------+ | 27 | 201326592 | +-----------------------+-----------------------+ | 28 | 402653184 | +-----------------------+-----------------------+ | 29 | 805306368 | +-----------------------+-----------------------+ | 30 | 1610612736 | +-----------------------+-----------------------+ | 31 | 3221225472 | +-----------------------+-----------------------+ | 32 | 6442450944 | +-----------------------+-----------------------+ | 33 | 12884901888 | +-----------------------+-----------------------+ | 34 | 25769803776 | +-----------------------+-----------------------+ | 35 | 51539607552 | +-----------------------+-----------------------+ | 36 | 103079215104 | +-----------------------+-----------------------+ | 37 | 206158430208 | +-----------------------+-----------------------+ | 38 | 412316860416 | +-----------------------+-----------------------+ | 39 | 824633720832 | +-----------------------+-----------------------+ | 40 | 1649267441664 | +-----------------------+-----------------------+ | 41 | 3298534883328 | +-----------------------+-----------------------+ | 42 | 6597069766656 | +-----------------------+-----------------------+ | 43 | 13194139533312 | +-----------------------+-----------------------+ | 44 | 26388279066624 | +-----------------------+-----------------------+ | 45 | 52776558133248 | +-----------------------+-----------------------+ | 46 | 105553116266496 | +-----------------------+-----------------------+ | 47 | 211106232532992 | +-----------------------+-----------------------+ | 48 | 422212465065984 | +-----------------------+-----------------------+ | 49 | 844424930131968 | +-----------------------+-----------------------+ | 50 | 1688849860263936 | +-----------------------+-----------------------+ | 51 | 3377699720527872 | +-----------------------+-----------------------+ | 52 | 6755399441055744 | +-----------------------+-----------------------+ | 53 | 13510798882111488 | +-----------------------+-----------------------+ | 54 | 27021597764222976 | +-----------------------+-----------------------+ | 55 | 54043195528445952 | +-----------------------+-----------------------+ | 56 | 108086391056891904 | +-----------------------+-----------------------+ | 57 | 216172782113783808 | +-----------------------+-----------------------+ | 58 | 432345564227567616 | +-----------------------+-----------------------+ | 59 | 864691128455135232 | +-----------------------+-----------------------+ | 60 | 1729382256910270464 | +-----------------------+-----------------------+ | 61 | 3458764513820540928 | +-----------------------+-----------------------+ | 62 | 6917529027641081856 | +-----------------------+-----------------------+ | 63 | 13835058055282163712 | +-----------------------+-----------------------+ | 64 | 27670116110564327424 | +-----------------------+-----------------------+ | 65 | 55340232221128654848 | +-----------------------+-----------------------+ | 66 | 110680464442257309696 | +-----------------------+-----------------------+ | 67 | 221360928884514619392 | +-----------------------+-----------------------+ | 68 | 442721857769029238784 | +-----------------------+-----------------------+ | 69 | 885443715538058477568 | +-----------------------+-----------------------+ | 70 | 1770887431076116955136| +-----------------------+-----------------------+ | 71 | 3541774862152233910272| +-----------------------+-----------------------+ | 72 | 7083549724304467820544| +-----------------------+-----------------------+ | 73 | 14167099448608935641088| +-----------------------+-----------------------+ | 74 | 28334198897217871282176| +-----------------------+-----------------------+ | 75 | 56668397794435742564352| +-----------------------+-----------------------+ | 76 | 113336795588871485128704| +-----------------------+-----------------------+ | 77 | 226673591177742970257408| +-----------------------+-----------------------+ | 78 | 453347182355485940514816| +-----------------------+-----------------------+ | 79 | 906694364710971881029632| +-----------------------+-----------------------+ | 80 | 1813388729421943762059264| +-----------------------+-----------------------+ | 81 | 3626777458843887524118528| +-----------------------+-----------------------+ | 82 | 7253554917687775048237056| +-----------------------+-----------------------+ | 83 | 14507109835375550096474112| +-----------------------+-----------------------+ | 84 | 29014219670751100192948224| +-----------------------+-----------------------+ | 85 | 58028439341502200385896448| +-----------------------+-----------------------+ | 86 | 116056878683004400771792896| +-----------------------+-----------------------+ | 87 | 232113757366008801543585792| +-----------------------+-----------------------+ | 88 | 464227514732017603087171584| +-----------------------+-----------------------+ | 89 | 928455029464035206174343168| +-----------------------+-----------------------+ | 90 | 1856910058928070412348686336| +-----------------------+-----------------------+ | 91 | 3713820117856140824697372672| +-----------------------+-----------------------+ | 92 | 7427640235712281649394745344| +-----------------------+-----------------------+ | 93 | 14855280471424563298789490688| +-----------------------+-----------------------+ | 94 | 29710560942849126597578981376| +-----------------------+-----------------------+ | 95 | 59421121885698253195157962752| +-----------------------+-----------------------+ | 96 | 118842243771396506390315925504| +-----------------------+-----------------------+ | 97 | 237684487542793012780631851008| +-----------------------+-----------------------+ | 98 | 475368975085586025561263702016| +-----------------------+-----------------------+ | 99 | 950737950171172051122527404032| +-----------------------+-----------------------+ | 100 | 1901475900342344102245054808064| +-----------------------+-----------------------+ | 101 | 3802951800684688204490109616128| +-----------------------+-----------------------+ | 102 | 7605903601369376408980219232256| +-----------------------+-----------------------+ | 103 | 15211807202738752817960438464512| +-----------------------+-----------------------+ | 104 | 30423614405477505635920876929024| +-----------------------+-----------------------+ | 105 | 60847228810955011271841753858048| +-----------------------+-----------------------+ | 106 | 121694457621910022543683507716096| +-----------------------+-----------------------+ | 107 | 243388915243820045087367015432192| +-----------------------+-----------------------+ | 108 | 486777830487640090174734030864384| +-----------------------+-----------------------+ | 109 | 973555660975280180349468061728768| +-----------------------+-----------------------+ | 110 | 1947111321950560360698936123457536| +-----------------------+-----------------------+ | 111 | 3894222643901120721397872246915072| +-----------------------+-----------------------+ | 112 | 7788445287802241 | ``n_resources_i`` | ``n_candidates_i`` | +=======================+=======================+ | 3 (=min_resources) | 70 (=n_candidates) | +-----------------------+-----------------------+ | 3 * 2 = 6 | 70 // 2 = 35 | +-----------------------+-----------------------+ | 6 * 2 = 12 | 35 // 2 = 17 | +-----------------------+-----------------------+ | 12 * 2 = 24 | 17 // 2 = 8 | +-----------------------+-----------------------+ | 24 * 2 = 48 | 8 // 2 = 4 | +-----------------------+-----------------------+ | 48 * 2 = 96 | 4 // 2 = 2 | +-----------------------+-----------------------+ 我们可以注意到: - 该过程在评估 `factor=2` 候选者的第一次迭代时停止:最佳候选者是这2个候选者中的最佳者。没有必要再进行额外的迭代,因为那只会评估一个候选者(即我们已经确定的最佳候选者)。因此,通常情况下,我们希望最后一次迭代最多运行 `factor` 个候选者。如果最后一次迭代评估的候选者数量超过 `factor` ,那么这次迭代就简化为常规搜索(如 :class:`RandomizedSearchCV` 或 :class:`GridSearchCV` )。 - 每个 ``n_resources_i`` 都是 ``factor`` 和 ``min_resources`` 的倍数(这一点在上面的定义中得到了证实)。 每次迭代使用的资源量可以在 `n_resources_` 属性中找到。 选择资源 --------- 默认情况下,资源是以样本数量定义的。也就是说,每次迭代将使用越来越多的样本来进行训练。然而,您可以通过 ``resource`` 参数手动指定一个参数作为资源。以下是一个示例,其中资源是以随机森林的估计器数量定义的:: >>> from sklearn.datasets import make_classification >>> from sklearn.ensemble import RandomForestClassifier >>> from sklearn.experimental import enable_halving_search_cv # noqa >>> from sklearn.model_selection import HalvingGridSearchCV >>> import pandas as pd >>> >>> param_grid = {'max_depth': [3, 5, 10], ... 'min_samples_split': [2, 5, 10]} >>> base_estimator = RandomForestClassifier(random_state=0) >>> X, y = make_classification(n_samples=1000, random_state=0) >>> sh = HalvingGridSearchCV(base_estimator, param_grid, cv=5, ... factor=2, resource='n_estimators', ... max_resources=30).fit(X, y) >>> sh.best_estimator_ RandomForestClassifier(max_depth=5, n_estimators=24, random_state=0) 请注意,无法对参数网格中的一部分参数进行预算。 .. _exhausting_the_resources: 耗尽可用资源 ------------ 如上所述,每次迭代使用的资源数量取决于 `min_resources` 参数。 如果您有大量可用资源但开始时资源数量较低,可能会浪费一些资源(即未使用):: >>> from sklearn.datasets import make_classification >>> from sklearn.svm import SVC >>> from sklearn.experimental import enable_halving_search_cv # noqa >>> from sklearn.model_selection import HalvingGridSearchCV >>> import pandas as pd >>> param_grid= {'kernel': ('linear', 'rbf'), ... 'C': [1, 10, 100]} >>> base_estimator = SVC(gamma='scale') >>> X, y = make_classification(n_samples=1000) >>> sh = HalvingGridSearchCV(base_estimator, param_grid, cv=5, ... factor=2, min_resources=20).fit(X, y) >>> sh.n_resources_ [20, 40, 80] 搜索过程最多只会使用 80 个资源,而我们最大可用资源数量为 ``n_samples=1000`` 。这里,我们有 ``min_resources = r_0 = 20`` 。 对于 :class:`HalvingGridSearchCV` ,默认情况下, `min_resources` 参数设置为 'exhaust'。这意味着 `min_resources` 会自动设置,使得最后一次迭代可以使用尽可能多的资源,在 `max_resources` 限制内:: >>> sh = HalvingGridSearchCV(base_estimator, param_grid, cv=5, ... factor=2, min_resources='exhaust').fit(X, y) >>> sh.n_resources_ [250, 500, 1000] 这里的 `min_resources` 自动设置为 250,这导致最后一次迭代使用了所有资源。具体使用的值取决于候选参数的数量、 `max_resources` 和 `factor` 。 对于 :class:`HalvingRandomSearchCV` ,耗尽资源可以通过两种方式完成: - 通过设置 `min_resources='exhaust'` ,就像在 :class:`HalvingGridSearchCV` 中一样; - 通过设置 `n_candidates='exhaust'` 。 这两种选项是互斥的:使用 `min_resources='exhaust'` 需要知道候选参数的数量,反之 `n_candidates='exhaust'` 需要知道 `min_resources` 。 一般来说,耗尽总资源数会得到更好的最终候选参数,但会稍微增加时间消耗。 .. _aggressive_elimination: 候选参数的积极消除 -------------------- 理想情况下,我们希望最后一次迭代评估 ``factor`` 个候选参数(参见 :ref:`amount_of_resource_and_number_of_candidates` )。然后我们只需选择最好的一个。当可用资源相对于候选参数数量较少时,最后一次迭代可能需要评估超过 ``factor`` 个候选参数:: >>> from sklearn.datasets import make_classification >>> from sklearn.svm import SVC >>> from sklearn.experimental import enable_halving_search_cv # noqa >>> from sklearn.model_selection import HalvingGridSearchCV >>> import pandas as pd >>> >>> >>> param_grid = {'kernel': ('linear', 'rbf'), ... 'C': [1, 10, 100]} >>> base_estimator = SVC(gamma='scale') >>> X, y = make_classification(n_samples=1000) >>> sh = HalvingGridSearchCV(base_estimator, param_grid, cv=5, ... factor=2, max_resources=40, ... aggressive_elimination=False).fit(X, y) >>> sh.n_resources_ [20, 40] >>> sh.n_candidates_ [6, 3] 由于我们不能使用超过 ``max_resources=40`` 的资源,进程必须在第二轮迭代时停止,该轮评估了超过 ``factor=2`` 的候选对象。 使用 ``aggressive_elimination`` 参数,您可以强制搜索进程在最后一轮迭代时得到少于 ``factor`` 的候选对象。为此,进程将使用 ``min_resources`` 资源尽可能多地淘汰候选对象:: >>> sh = HalvingGridSearchCV(base_estimator, param_grid, cv=5, ... factor=2, ... max_resources=40, ... aggressive_elimination=True, ... ).fit(X, y) >>> sh.n_resources_ [20, 20, 40] >>> sh.n_candidates_ [6, 3, 2] 注意,我们在最后一轮迭代时只剩下 2 个候选对象,因为我们已经在第一轮迭代中使用 ``n_resources = min_resources = 20`` 淘汰了足够的候选对象。 .. _successive_halving_cv_results: 使用 `cv_results_` 属性分析结果 -------------------------------------------------- ``cv_results_`` 属性包含用于分析搜索结果的有用信息。它可以转换为 pandas 数据框,使用 ``df = pd.DataFrame(est.cv_results_)`` 。:class:`HalvingGridSearchCV` 和 :class:`HalvingRandomSearchCV` 的 ``cv_results_`` 属性类似于 :class:`GridSearchCV` 和 :class:`RandomizedSearchCV` ,但包含了与连续减半过程相关的额外信息。 以下是一个包含部分列的(截断的)数据框示例: ==== ====== =============== ================= ======================================================================================== .. iter n_resources mean_test_score params ==== ====== =============== ================= ======================================================================================== 0 0 125 0.983667 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 5} 1 0 125 0.983667 {'criterion': 'gini', 'max_depth': None, 'max_features': 8, 'min_samples_split': 7} 2 0 125 0.983667 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 10} 3 0 125 0.983667 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 6, 'min_samples_split': 6} ... ... ... ... ... 15 2 500 0.951958 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10} 16 2 500 0.947958 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 10} 17 2 500 0.951958 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 4} 18 3 1000 0.961009 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10} 19 3 1000 0.955989 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 4} ==== ====== =============== ================= ======================================================================================== 每一行对应一个给定的参数组合(候选)和一个给定的迭代。迭代由 ``iter`` 列给出。 ``n_resources`` 列告诉您使用了多少资源。 在上面的示例中,最佳参数组合是 ``{'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}`` ,因为它达到了最后一轮迭代(3)并获得了最高分数:0.96。 .. rubric:: 参考文献 .. [1] K. Jamieson, A. Talwalkar, `非随机最佳臂识别和超参数优化 `_ , 在机器学习研究会议论文集, 2016. .. [2] L. Li, K. Jamieson, G. DeSalvo, A. Rostamizadeh, A. Talwalkar, :arxiv:`Hyperband: 一种基于强盗的新型超参数优化方法 <1603.06560>`_ , 在机器学习研究 18, 2018. .. _grid_search_tips: 参数搜索技巧 ============== .. _gridsearch_scoring: 指定目标度量标准 ------------------ 默认情况下,参数搜索使用估计器的 ``score`` 函数来评估参数设置。这些是分类的 :func:`sklearn.metrics.accuracy_score` 和回归的 :func:`sklearn.metrics.r2_score` 。对于某些应用,其他评分函数更合适(例如在不平衡分类中,准确度分数通常不具有信息性)。可以通过大多数参数搜索工具的 ``scoring`` 参数指定替代评分函数。有关更多详细信息,请参阅 :ref:`scoring_parameter` 。 .. _multimetric_grid_search: 指定多个度量标准进行评估 -------------------------- :class:`GridSearchCV` 和 :class:`RandomizedSearchCV` 允许为 ``scoring`` 参数指定多个度量标准。 多度量评分可以指定为预定义评分名称的字符串列表,或映射评分器名称到评分器函数和/或预定义评分器名称的字典。有关更多详细信息,请参阅 :ref:`multimetric_scoring` 。 当指定多个度量标准时,必须将 ``refit`` 参数设置为 用于确定将在整个数据集上构建 ``best_estimator_`` 的 ``best_params_`` 的指标(字符串)。如果不需要重新拟合搜索结果,请设置 ``refit=False`` 。将 refit 保留为默认值 ``None`` 在使用多个指标时会导致错误。 有关示例用法,请参阅 :ref:`sphx_glr_auto_examples_model_selection_plot_multi_metric_evaluation.py` 。 :class:`HalvingRandomSearchCV` 和 :class:`HalvingGridSearchCV` 不支持多指标评分。 .. _composite_grid_search: 复合估计器和参数空间 ---------------------- :class:`GridSearchCV` 和 :class:`RandomizedSearchCV` 允许搜索复合或嵌套估计器的参数,例如 :class:`~sklearn.pipeline.Pipeline` 、:class:`~sklearn.compose.ColumnTransformer` 、:class:`~sklearn.ensemble.VotingClassifier` 或 :class:`~sklearn.calibration.CalibratedClassifierCV` ,使用专门的 ``__`` 语法:: >>> from sklearn.model_selection import GridSearchCV >>> from sklearn.calibration import CalibratedClassifierCV >>> from sklearn.ensemble import RandomForestClassifier >>> from sklearn.datasets import make_moons >>> X, y = make_moons() >>> calibrated_forest = CalibratedClassifierCV( ... estimator=RandomForestClassifier(n_estimators=10)) >>> param_grid = { ... 'estimator__max_depth': [2, 4, 6, 8]} >>> search = GridSearchCV(calibrated_forest, param_grid, cv=5) >>> search.fit(X, y) GridSearchCV(cv=5, estimator=CalibratedClassifierCV(...), param_grid={'estimator__max_depth': [2, 4, 6, 8]}) 在这里, ```` 是嵌套估计器的参数名称,在本例中为 ``estimator`` 。如果元估计器是作为估计器集合构建的,例如在 `pipeline.Pipeline` 中,那么 ```` 指的是估计器的名称,参见 :ref:`pipeline_nested_parameters` 。实际上,可以有多个嵌套级别:: >>> from sklearn.pipeline import Pipeline >>> from sklearn.feature_selection import SelectKBest >>> pipe = Pipeline([ ... ('select', SelectKBest()), ... ('model', calibrated_forest)]) >>> param_grid = { ... 'select__k': [1, 2], ... 'model__estimator__max_depth': [2, 4, 6, 8]} >>> search = GridSearchCV(pipe, param_grid, cv=5).fit(X, y) 请参考 :ref:`pipeline` 以执行管道上的参数搜索。 模型选择:开发与评估 ---------------------- 通过评估各种参数设置来进行模型选择,可以看作是利用标记数据来“训练”网格参数的一种方式。 在评估最终模型时,重要的是要在网格搜索过程中未见过的样本上进行评估:建议将数据分为**开发集**(用于输入 :class:`GridSearchCV` 实例)和**评估集**,以计算性能指标。 这可以通过使用 :func:`train_test_split` 实用函数来实现。 并行性 -------- 参数搜索工具在每个数据折叠上独立评估每个参数组合。可以通过使用关键字 ``n_jobs=-1`` 来并行运行计算。有关更多详细信息,请参见函数签名,以及术语表中关于 :term:`n_jobs` 的条目。 对失败的鲁棒性 ---------------- 某些参数设置可能会导致无法拟合数据的一个或多个折叠。默认情况下,这将导致整个搜索失败,即使某些参数设置可以完全评估。设置 ``error_score=0`` (或 `=np.nan` )将使过程对这种失败具有鲁棒性,发出警告并将该折叠的分数设置为 0(或 `nan` ),但完成搜索。 .. _alternative_cv: 暴力参数搜索的替代方案 ======================== 特定于模型的交叉验证 ---------------------- 一些模型可以几乎同样高效地拟合某个参数范围内的一组数据,就像拟合该参数单一值的估计器一样。这一特性可以用于执行更高效的交叉验证,以进行该参数的模型选择。 最常见的适用于此策略的参数是编码正则化器强度的参数。在这种情况下,我们称我们计算了估计器的**正则化路径**。 以下是此类模型的列表: .. currentmodule:: sklearn .. autosummary:: linear_model.ElasticNetCV linear_model.LarsCV linear_model.LassoCV linear_model.LassoLarsCV linear_model.LogisticRegressionCV linear_model.MultiTaskElasticNetCV linear_model.MultiTaskLassoCV linear_model.OrthogonalMatchingPursuitCV linear_model.RidgeCV linear_model.RidgeClassifierCV 信息准则 --------- 一些模型可以通过计算单个正则化路径(而不是使用交叉验证时的多个路径),提供一个基于信息论的正则化参数最优估计的封闭形式公式。 以下是受益于赤池信息量准则(AIC)或贝叶斯信息量准则(BIC)进行自动模型选择的模型列表: .. autosummary:: linear_model.LassoLarsIC .. _out_of_bag: 包外估计 --------- 当使用基于装袋的集成方法,即通过有放回抽样生成新的训练集时,部分训练集未被使用。对于集成中的每个分类器,不同的部分训练集被留出。 这部分留出的数据可以用来估计泛化误差,而不必依赖于单独的验证集。这种估计是“免费的”,因为不需要额外的数据,并且可以用于模型选择。 目前这在以下类中实现: .. autosummary:: ensemble.RandomForestClassifier ensemble.RandomForestRegressor ensemble.ExtraTreesClassifier ensemble.ExtraTreesRegressor ensemble.GradientBoostingClassifier ensemble.GradientBoostingRegressor