1. 元数据路由#
Note
元数据路由API是实验性的,尚未在所有估计器中实现。有关更多信息,请参阅支持和不支持模型的列表 list of supported and unsupported
models 。它可能会在没有通常的弃用周期的情况下发生变化。默认情况下,此功能未启用。您可以通过将 enable_metadata_routing
标志设置为 True
来启用它:
>>> import sklearn
>>> sklearn.set_config(enable_metadata_routing=True)
请注意,本文档中介绍的方法和要求仅在您希望将 metadata (例如 sample_weight
)传递给方法时才相关。如果您仅传递 X
和 y
,而不传递其他参数/元数据给 fit 、transform 等方法,则无需设置任何内容。
本指南演示了如何在scikit-learn中将 metadata 路由和传递给对象。如果您正在开发与scikit-learn兼容的估计器或元估计器,可以查看我们的相关开发者指南 元数据路由 。
元数据是估计器、评分器或CV分割器在用户显式将其作为参数传递时考虑的数据。例如,KMeans
在其 fit()
方法中接受 sample_weight
,并使用它来计算其质心。 classes
被某些分类器消耗, groups
被某些分割器使用,但除了X和y之外传递给对象方法的任何数据都可以视为元数据。在scikit-learn 1.3版本之前,如果这些对象与其他对象一起使用,例如在:class:~model_selection.GridSearchCV
中接受 sample_weight
的评分器,则没有单一的API来传递此类元数据。
使用Metadata Routing API,我们可以通过 meta-estimators (例如 Pipeline
或 GridSearchCV
) 或如:func:~model_selection.cross_validate
这样的函数将元数据传递给估计器、评分器和CV分割器,这些函数将数据路由到其他对象。为了将元数据传递给诸如 fit
或 score
这样的方法,消费元数据的对象必须*请求*它。这是通过 set_{method}_request()
方法完成的,其中 {method}
被替换为请求元数据的方法的名称。例如,在其 fit()
方法中使用元数据的估计器将使用 set_fit_request()
,而评分器将使用 set_score_request()
。这些方法允许我们指定要请求哪些元数据,例如 set_fit_request(sample_weight=True)
。
对于如:class:~model_selection.GroupKFold
这样的分组分割器,默认情况下会请求 groups
参数。以下示例最好地展示了这一点。
1.1. 使用示例#
这里我们展示几个示例,展示一些常见的用例。我们的目标是传递 sample_weight
和 groups
通过 cross_validate
,它将元数据路由到:class:~linear_model.LogisticRegressionCV
和一个使用 make_scorer
制作的定制评分器,这两个都可以在其方法中使用元数据。在这些示例中,我们希望分别设置是否在不同的 consumers 中使用元数据。
本节中的示例需要以下导入和数据:
>>> import numpy as np
>>> from sklearn.metrics import make_scorer, accuracy_score
>>> from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
>>> from sklearn.model_selection import cross_validate, GridSearchCV, GroupKFold
>>> from sklearn.feature_selection import SelectKBest
>>> from sklearn.pipeline import make_pipeline
>>> n_samples, n_features = 100, 4
>>> rng = np.random.RandomState(42)
1.1.1. 加权评分和拟合#
内部使用的分割器 LogisticRegressionCV
、 GroupKFold
默认请求 groups
。然而,我们需要通过在 LogisticRegressionCV
的 set_fit_request()
方法和 make_scorer
的 set_score_request()
方法中指定 sample_weight=True
来显式请求 sample_weight
。这两个 consumers 都知道如何在它们的 fit()
或 score()
方法中使用 sample_weight
。然后我们可以在 cross_validate
中传递元数据,这将把它路由到任何活动的消费者:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(),
... scoring=weighted_acc
... ).set_fit_request(sample_weight=True)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... params={"sample_weight": my_weights, "groups": my_groups},
... cv=GroupKFold(),
... scoring=weighted_acc,
... )
请注意,在这个例子中,cross_validate
将 my_weights
路由到评分器和 LogisticRegressionCV
。
如果我们将在 cross_validate
的参数中传递 sample_weight
,但没有设置任何对象来请求它,将引发 UnsetMetadataPassedError
,提示我们需要显式设置路由位置。同样,如果传递了 params={"sample_weights": my_weights, ...}
(注意拼写错误,即 weights
而不是 weight
),由于 sample_weights
没有被任何底层对象请求,也会发生同样的情况。
1.1.2. 加权评分与非加权拟合#
当将元数据(如 sample_weight
)传递给 router ( meta-estimators 或路由函数)时,所有 sample_weight
consumers 都需要明确请求或明确不请求权重(即 True
或 False
)。因此,为了执行非加权拟合,我们需要配置 LogisticRegressionCV
不请求样本权重,以使 cross_validate
不传递这些权重:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).set_fit_request(sample_weight=False)
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... params={"sample_weight": my_weights, "groups": my_groups},
... scoring=weighted_acc,
... )
如果未调用 linear_model.LogisticRegressionCV.set_fit_request
,cross_validate
将引发错误,因为 sample_weight
被传递但 LogisticRegressionCV
未明确配置为识别这些权重。
1.1.3. 非加权特征选择#
仅当对象的方法知道如何使用元数据时,才能进行元数据路由,这通常意味着它们将其作为显式参数。只有在这种情况下,我们才能使用 set_fit_request(sample_weight=True)
等方法为元数据设置请求值。这使得对象成为 consumer 。
与 LogisticRegressionCV
不同, SelectKBest
无法消费权重,因此在其实例上未设置 sample_weight
的请求值,并且 sample_weight
不会路由到它:
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(sample_weight=True)
1.1.4. 不同的评分和拟合权重#
尽管 make_scorer
和 LogisticRegressionCV
都期望接收键 sample_weight
,我们可以使用别名将不同的权重传递给不同的消费者。在这个例子中,我们将 scoring_weight
传递给评分器,并将 fitting_weight
传递给 LogisticRegressionCV
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
... sample_weight="scoring_weight"
... )
>>> lr = LogisticRegressionCV(
... cv=GroupKFold(), scoring=weighted_acc,
... ).set_fit_request(sample_weight="fitting_weight")
>>> cv_results = cross_validate(
... lr,
... X,
... y,
... cv=GroupKFold(),
... params={
... "scoring_weight": my_weights,
... "fitting_weight": my_other_weights,
... "groups": my_groups,
... },
... scoring=weighted_acc,
... )
1.2. API 接口#
consumer 是一个对象(估计器、元估计器、评分器、分割器),它在其方法(例如 fit
、 predict
、 inverse_transform
、 transform
、 score
、 split
)中至少接受并使用某些 metadata 。仅将元数据转发给其他对象(子估计器、评分器或分割器)而不自行使用元数据的元估计器不是消费者。(元)估计器将元数据路由到其他对象的是 router 。一个(元)估计器可以同时是 consumer 和 router 。
(Meta-)Estimators和splitters为每个方法暴露了一个 set_{method}_request
方法,该方法至少接受一个元数据。例如,如果一个estimator在 fit
和 score
中支持 sample_weight
,它会暴露 estimator.set_fit_request(sample_weight=value)
和 estimator.set_score_request(sample_weight=value)
。这里的 value
可以是:
True
:方法请求一个sample_weight
。这意味着如果提供了元数据,它将被使用,否则不会引发错误。False
:方法不请求sample_weight
。None
:如果传递了sample_weight
,路由器将引发错误。这几乎在所有情况下都是对象实例化时的默认值,并确保当传递元数据时,用户明确设置元数据请求。唯一的例外是Group*Fold
splitters。"param_name"
:如果我们要向不同的消费者传递不同的权重,这是sample_weight
的别名。如果使用了别名,元估计器不应该将"param_name"
转发给消费者,而是转发sample_weight
,因为消费者会期望一个名为sample_weight
的参数。这意味着对象所需的元数据(例如sample_weight
)和用户提供的变量名(例如my_weights
)之间的映射是在路由器级别完成的,而不是由消费对象本身完成的。
使用 set_score_request
以相同的方式为评分器请求元数据。
如果用户传递了一个元数据,例如 sample_weight
,所有可能消费 sample_weight
的对象的元数据请求应由用户设置,否则路由器对象会引发错误。例如,以下代码会引发错误,因为它没有明确指定是否应将 sample_weight
传递给估计器的评分器:
>>> param_grid = {"C": [0.1, 1]}
>>> lr = LogisticRegression().set_fit_request(sample_weight=True)
>>> try:
... GridSearchCV(
... estimator=lr, param_grid=param_grid
... ).fit(X, y, sample_weight=my_weights)
... except ValueError as e:
... print(e)
[sample_weight] 被传递但未显式设置为请求或未请求,这与在 GridSearchCV.fit 中使用的 LogisticRegression.score 有关。
对于每个您想要请求/忽略的元数据,请调用 `LogisticRegression.set_score_request({metadata}=True/False)` 。
可以通过显式设置请求值来解决此问题:
>>> lr = LogisticRegression().set_fit_request(
... sample_weight=True
... ).set_score_request(sample_weight=False)
在 使用示例 部分的最后,我们禁用元数据路由的配置标志:
>>> sklearn.set_config(enable_metadata_routing=False)
1.3. 元数据路由支持状态#
所有消费者(即仅消费元数据而不路由它们的简单估计器)都支持元数据路由,这意味着它们可以在支持元数据路由的元估计器内部使用。然而,对元估计器的元数据路由支持正在开发中,以下是支持和不支持元数据路由的元估计器和工具列表。
支持元数据路由的元估计器和函数:
尚不支持元数据路由的元估计器和工具: