.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/miscellaneous/plot_metadata_routing.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. or to run this example in your browser via Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py: ================ 元数据路由 ================ .. currentmodule:: sklearn 本文档展示了如何在scikit-learn中使用:ref:`元数据路由机制` 将元数据路由到使用它们的估计器、评分器和交叉验证分割器。 为了更好地理解以下文档,我们需要介绍两个概念:路由器和消费者。路由器是一个将给定数据和元数据转发到其他对象的对象。在大多数情况下,路由器是一个:term:`元估计器` ,即一个将另一个估计器作为参数的估计器。像:func:`sklearn.model_selection.cross_validate` 这样的函数,它将估计器作为参数并转发数据和元数据,也是一个路由器。 另一方面,消费者是一个接受并使用给定元数据的对象。例如,在其:term:`fit` 方法中考虑 ``sample_weight`` 的估计器是 ``sample_weight`` 的消费者。 一个对象可以同时是路由器和消费者。例如,一个元估计器可能在某些计算中考虑 ``sample_weight`` ,但它也可能将其路由到底层估计器。 首先是一些导入和一些用于其余脚本的随机数据。 .. GENERATED FROM PYTHON SOURCE LINES 20-53 .. code-block:: Python import warnings from pprint import pprint import numpy as np from sklearn import set_config from sklearn.base import ( BaseEstimator, ClassifierMixin, MetaEstimatorMixin, RegressorMixin, TransformerMixin, clone, ) from sklearn.linear_model import LinearRegression from sklearn.utils import metadata_routing from sklearn.utils.metadata_routing import ( MetadataRouter, MethodMapping, get_routing_for_object, process_routing, ) from sklearn.utils.validation import check_is_fitted n_samples, n_features = 100, 4 rng = np.random.RandomState(42) X = rng.rand(n_samples, n_features) y = rng.randint(0, 2, size=n_samples) my_groups = rng.randint(0, 10, size=n_samples) my_weights = rng.rand(n_samples) my_other_weights = rng.rand(n_samples) .. GENERATED FROM PYTHON SOURCE LINES 54-55 元数据路由仅在明确启用时可用: .. GENERATED FROM PYTHON SOURCE LINES 55-59 .. code-block:: Python set_config(enable_metadata_routing=True) .. GENERATED FROM PYTHON SOURCE LINES 60-61 这个实用函数是一个虚拟函数,用于检查是否传递了元数据: .. GENERATED FROM PYTHON SOURCE LINES 61-72 .. code-block:: Python def check_metadata(obj, **kwargs): for key, value in kwargs.items(): if value is not None: print( f"Received {key} of length = {len(value)} in {obj.__class__.__name__}." ) else: print(f"{key} is None in {obj.__class__.__name__}.") .. GENERATED FROM PYTHON SOURCE LINES 73-74 一个用于美观地打印对象路由信息的实用函数: .. GENERATED FROM PYTHON SOURCE LINES 74-79 .. code-block:: Python def print_routing(obj): pprint(obj.get_metadata_routing()._serialize()) .. GENERATED FROM PYTHON SOURCE LINES 80-85 消耗估计器 ----------- 在这里我们演示了一个估计器如何暴露所需的 API 以支持作为消费者的元数据路由。想象一个简单的分类器在其 ``fit`` 方法中接受 ``sample_weight`` 作为元数据,并在其 ``predict`` 方法中接受 ``groups`` : .. GENERATED FROM PYTHON SOURCE LINES 85-101 .. code-block:: Python class ExampleClassifier(ClassifierMixin, BaseEstimator): def fit(self, X, y, sample_weight=None): check_metadata(self, sample_weight=sample_weight) # 所有分类器在拟合后都需要暴露一个 classes_ 属性。 self.classes_ = np.array([0, 1]) return self def predict(self, X, groups=None): check_metadata(self, groups=groups) # 返回一个常数值1,不是一个非常聪明的分类器! return np.ones(len(X)) .. GENERATED FROM PYTHON SOURCE LINES 102-105 上述估计器现在具备了处理元数据所需的一切。这是通过一些在 :class:`~base.BaseEstimator` 中完成的魔法实现的。现在,上述类公开了三个方法: ``set_fit_request`` 、 ``set_predict_request`` 和 ``get_metadata_routing`` 。此外,还有一个用于 ``sample_weight`` 的 ``set_score_request`` ,因为 :class:`~base.ClassifierMixin` 实现了一个接受 ``sample_weight`` 的 ``score`` 方法。同样的情况也适用于继承自 :class:`~base.RegressorMixin` 的回归器。 默认情况下,不会请求任何元数据,我们可以看到如下: .. GENERATED FROM PYTHON SOURCE LINES 105-108 .. code-block:: Python print_routing(ExampleClassifier()) .. rst-class:: sphx-glr-script-out .. code-block:: none {'fit': {'sample_weight': None}, 'predict': {'groups': None}, 'score': {'sample_weight': None}} .. GENERATED FROM PYTHON SOURCE LINES 109-110 上述输出意味着 `ExampleClassifier` 并不需要 `sample_weight` 和 `groups` ,如果路由器接收到这些元数据,它应该抛出一个错误,因为用户没有明确设置它们是否是必需的。同样的情况也适用于从 :class:`~base.ClassifierMixin` 继承的 `score` 方法中的 `sample_weight` 。为了明确设置这些元数据的请求值,我们可以使用以下方法: .. GENERATED FROM PYTHON SOURCE LINES 110-120 .. code-block:: Python est = ( ExampleClassifier() .set_fit_request(sample_weight=False) .set_predict_request(groups=True) .set_score_request(sample_weight=False) ) print_routing(est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'fit': {'sample_weight': False}, 'predict': {'groups': True}, 'score': {'sample_weight': False}} .. GENERATED FROM PYTHON SOURCE LINES 121-123 .. 注意 :: 请注意,只要上述估计器未在元估计器中使用,用户就不需要设置任何元数据请求,并且设置的值将被忽略,因为消费者不会验证或路由给定的元数据。上述估计器的简单使用将按预期工作。 .. GENERATED FROM PYTHON SOURCE LINES 123-129 .. code-block:: Python est = ExampleClassifier() est.fit(X, y, sample_weight=my_weights) est.predict(X[:3, :], groups=my_groups) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. Received groups of length = 100 in ExampleClassifier. array([1., 1., 1.]) .. GENERATED FROM PYTHON SOURCE LINES 130-133 路由元估计器 ---------------------- 现在,我们展示如何设计一个元估计器作为路由器。作为一个简化的例子,这里有一个元估计器,它除了路由元数据之外没有做太多其他事情。 .. GENERATED FROM PYTHON SOURCE LINES 133-179 .. code-block:: Python class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): def __init__(self, estimator): self.estimator = estimator def get_metadata_routing(self): # 此方法定义了该元估计器的路由。 # 为此,会创建一个 `MetadataRouter` 实例,并将路由添加到其中。更多解释如下。 router = MetadataRouter(owner=self.__class__.__name__).add( estimator=self.estimator, method_mapping=MethodMapping() .add(caller="fit", callee="fit") .add(caller="predict", callee="predict") .add(caller="score", callee="score"), ) return router def fit(self, X, y, **fit_params): # `get_routing_for_object` 返回一个由上述 `get_metadata_routing` 方法构建的 `MetadataRouter` 的副本,该方法在内部被调用。 request_router = get_routing_for_object(self) # 元估计器负责验证给定的元数据。 # `method` 指的是父类的方法,即本例中的 `fit` 。 request_router.validate_metadata(params=fit_params, method="fit") # `MetadataRouter.route_params` 根据 MetadataRouter 定义的路由信息,将给定的元数据映射到底层估计器所需的元数据。类型为 `Bunch` 的输出包含每个消费对象的键,这些键包含其消费方法的键,然后这些键包含应路由到它们的元数据的键。 routed_params = request_router.route_params(params=fit_params, caller="fit") # 子估计器被拟合,其类别被归属于元估计器。 self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) self.classes_ = self.estimator_.classes_ return self def predict(self, X, **predict_params): check_is_fitted(self) # 正如在 `fit` 中一样,我们获得了对象的 MetadataRouter 的副本, request_router = get_routing_for_object(self) # 然后我们验证给定的元数据, request_router.validate_metadata(params=predict_params, method="predict") # 然后准备底层 `predict` 方法的输入。 routed_params = request_router.route_params( params=predict_params, caller="predict" ) return self.estimator_.predict(X, **routed_params.estimator.predict) .. GENERATED FROM PYTHON SOURCE LINES 180-187 让我们分解上述代码的不同部分。 首先,:meth:`~utils.metadata_routing.get_routing_for_object` 方法接收我们的元估计器( ``self`` ),并返回一个 :class:`~utils.metadata_routing.MetadataRouter` 对象,或者如果对象是消费者,则返回一个 :class:`~utils.metadata_routing.MetadataRequest` 对象,这取决于估计器的 ``get_metadata_routing`` 方法的输出。 然后在每个方法中,我们使用 ``route_params`` 方法构建一个形式为 ``{"object_name": {"method_name": {"metadata": value}}}`` 的字典,以传递给底层估计器的方法。 ``object_name`` (在上述 ``routed_params.estimator.fit`` 示例中为 ``estimator`` )与在 ``get_metadata_routing`` 中添加的相同。 ``validate_metadata`` 确保所有给定的元数据都被请求,以避免静默错误。 接下来,我们将说明不同的行为,特别是引发的错误类型。 .. GENERATED FROM PYTHON SOURCE LINES 187-193 .. code-block:: Python meta_est = MetaClassifier( estimator=ExampleClassifier().set_fit_request(sample_weight=True) ) meta_est.fit(X, y, sample_weight=my_weights) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. .. raw:: html
MetaClassifier(estimator=ExampleClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 194-195 请注意,上述示例通过 `ExampleClassifier` 调用了我们的实用函数 `check_metadata()` 。它检查 ``sample_weight`` 是否被正确传递。如果没有正确传递,如以下示例所示,它将打印 ``sample_weight`` 为 ``None`` : .. GENERATED FROM PYTHON SOURCE LINES 195-199 .. code-block:: Python meta_est.fit(X, y) .. rst-class:: sphx-glr-script-out .. code-block:: none sample_weight is None in ExampleClassifier. .. raw:: html
MetaClassifier(estimator=ExampleClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 200-201 如果我们传递了未知的元数据,将会引发错误: .. GENERATED FROM PYTHON SOURCE LINES 201-207 .. code-block:: Python try: meta_est.fit(X, y, test=my_weights) except TypeError as e: print(e) .. rst-class:: sphx-glr-script-out .. code-block:: none MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object. .. GENERATED FROM PYTHON SOURCE LINES 208-209 如果我们传递一个未明确请求的元数据: .. GENERATED FROM PYTHON SOURCE LINES 209-215 .. code-block:: Python try: meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups) except ValueError as e: print(e) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. [groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore. .. GENERATED FROM PYTHON SOURCE LINES 216-217 此外,如果我们明确设置为不请求,但它已被提供: .. GENERATED FROM PYTHON SOURCE LINES 217-228 .. code-block:: Python meta_est = MetaClassifier( estimator=ExampleClassifier() .set_fit_request(sample_weight=True) .set_predict_request(groups=False) ) try: meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups) except TypeError as e: print(e) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object. .. GENERATED FROM PYTHON SOURCE LINES 229-230 另一个需要介绍的概念是 **别名元数据** 。这是指估计器请求的元数据使用了不同于默认变量名的变量名。例如,在一个包含两个估计器的管道中,一个可以请求 ``sample_weight1`` ,另一个可以请求 ``sample_weight2`` 。需要注意的是,这并不会改变估计器的预期,它只是告诉元估计器如何将提供的元数据映射到所需的内容。以下是一个示例,我们将 ``aliased_sample_weight`` 传递给元估计器,但元估计器理解 ``aliased_sample_weight`` 是 ``sample_weight`` 的别名,并将其作为 ``sample_weight`` 传递给底层估计器: .. GENERATED FROM PYTHON SOURCE LINES 230-236 .. code-block:: Python meta_est = MetaClassifier( estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight") ) meta_est.fit(X, y, aliased_sample_weight=my_weights) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. .. raw:: html
MetaClassifier(estimator=ExampleClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 237-238 传递 ``sample_weight`` 会失败,因为它是通过别名请求的,而没有请求名为 ``sample_weight`` 的参数: .. GENERATED FROM PYTHON SOURCE LINES 238-244 .. code-block:: Python try: meta_est.fit(X, y, sample_weight=my_weights) except TypeError as e: print(e) .. rst-class:: sphx-glr-script-out .. code-block:: none MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object. .. GENERATED FROM PYTHON SOURCE LINES 245-246 这将引导我们了解 ``get_metadata_routing`` 。在 scikit-learn 中,路由的工作方式是消费者请求他们需要的内容,而路由器则传递这些内容。此外,路由器会公开其自身的需求,以便可以在另一个路由器内部使用,例如网格搜索对象中的管道。 ``get_metadata_routing`` 的输出是 :class:`~utils.metadata_routing.MetadataRouter` 的字典表示形式,其中包括所有嵌套对象请求的元数据的完整树及其相应的方法路由,即子估计器的哪个方法在元估计器的哪个方法中使用。 .. GENERATED FROM PYTHON SOURCE LINES 246-250 .. code-block:: Python print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': 'aliased_sample_weight'}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 251-254 如你所见,方法 ``fit`` 唯一请求的元数据是 ``"sample_weight"`` ,其别名为 ``"aliased_sample_weight"`` 。 ``~utils.metadata_routing.MetadataRouter`` 类使我们能够轻松创建路由对象,从而生成 ``get_metadata_routing`` 所需的输出。 为了理解别名在元估计器中的工作原理,想象我们的元估计器在另一个元估计器内部: .. GENERATED FROM PYTHON SOURCE LINES 254-259 .. code-block:: Python meta_meta_est = MetaClassifier(estimator=meta_est).fit( X, y, aliased_sample_weight=my_weights ) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleClassifier. .. GENERATED FROM PYTHON SOURCE LINES 260-273 在上面的例子中, `meta_meta_est` 的 ``fit`` 方法将这样调用其子估计器的 ``fit`` 方法: 用户将 `my_weights` 作为 `aliased_sample_weight` 输入到 `meta_meta_est` 中: meta_meta_est.fit(X, y, aliased_sample_weight=my_weights): ... 第一个子估计器 ( `meta_est` ) 需要 `aliased_sample_weight` self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight): ... 第二个子估计器( `est` )需要 `sample_weight` self.estimator_.fit(X, y, sample_weight=aliased_sample_weight): ... .. GENERATED FROM PYTHON SOURCE LINES 275-278 消耗和路由元估计器 -------------------- 对于一个稍微复杂一点的例子,考虑一个元估计器,它像之前一样将元数据路由到底层估计器,但它也在自己的方法中使用一些元数据。这个元估计器同时是一个消费者和路由器。实现这样的元估计器与之前的实现非常相似,但有一些调整。 .. GENERATED FROM PYTHON SOURCE LINES 278-331 .. code-block:: Python class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): def __init__(self, estimator): self.estimator = estimator def get_metadata_routing(self): router = ( MetadataRouter(owner=self.__class__.__name__) # 定义元数据路由请求值以供元估计器使用 .add_self_request(self) # 定义元数据路由请求值以供子估算器使用 .add( estimator=self.estimator, method_mapping=MethodMapping() .add(caller="fit", callee="fit") .add(caller="predict", callee="predict") .add(caller="score", callee="score"), ) ) return router # 由于 `sample_weight` 在此处被使用和消耗,因此应在方法的签名中将其定义为显式参数。所有其他仅被传递的元数据将作为 `**fit_params` 传递: def fit(self, X, y, sample_weight, **fit_params): if self.estimator is None: raise ValueError("estimator cannot be None!") check_metadata(self, sample_weight=sample_weight) # 我们将 `sample_weight` 添加到 `fit_params` 字典中。 if sample_weight is not None: fit_params["sample_weight"] = sample_weight request_router = get_routing_for_object(self) request_router.validate_metadata(params=fit_params, method="fit") routed_params = request_router.route_params(params=fit_params, caller="fit") self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) self.classes_ = self.estimator_.classes_ return self def predict(self, X, **predict_params): check_is_fitted(self) # 正如在 `fit` 中一样,我们获得了对象的 MetadataRouter 的副本, request_router = get_routing_for_object(self) # 我们验证了给定的元数据, request_router.validate_metadata(params=predict_params, method="predict") # 然后准备底层 ``predict`` 方法的输入。 routed_params = request_router.route_params( params=predict_params, caller="predict" ) return self.estimator_.predict(X, **routed_params.estimator.predict) .. GENERATED FROM PYTHON SOURCE LINES 332-335 上述元估计器与我们之前的元估计器的主要区别在于,它在 ``fit`` 中显式接受 ``sample_weight`` 并将其包含在 ``fit_params`` 中。由于 ``sample_weight`` 是一个显式参数,我们可以确保该方法中存在 ``set_fit_request(sample_weight=...)`` 。该元估计器既是 ``sample_weight`` 的使用者,也是路由器。 在 ``get_metadata_routing`` 中,我们使用 ``add_self_request`` 将 ``self`` 添加到路由中,以表明此估计器正在使用 ``sample_weight`` 并且也是一个路由器;这还会在路由信息中添加一个 ``$self_request`` 键,如下所示。现在让我们来看一些例子: .. GENERATED FROM PYTHON SOURCE LINES 337-338 - 未请求元数据 .. GENERATED FROM PYTHON SOURCE LINES 338-343 .. code-block:: Python meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()) print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'$self_request': {'fit': {'sample_weight': None}, 'score': {'sample_weight': None}}, 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': None}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 344-345 - 子估计器请求的 ``sample_weight`` .. GENERATED FROM PYTHON SOURCE LINES 345-351 .. code-block:: Python meta_est = RouterConsumerClassifier( estimator=ExampleClassifier().set_fit_request(sample_weight=True) ) print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'$self_request': {'fit': {'sample_weight': None}, 'score': {'sample_weight': None}}, 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': True}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 352-353 - ``sample_weight`` 被元估计器请求 .. GENERATED FROM PYTHON SOURCE LINES 353-359 .. code-block:: Python meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request( sample_weight=True ) print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'$self_request': {'fit': {'sample_weight': True}, 'score': {'sample_weight': None}}, 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': None}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 360-363 请注意上面请求的元数据表示之间的差异。 - 我们还可以为元估计器和子估计器的拟合方法传递不同的值来设置元数据的别名: .. GENERATED FROM PYTHON SOURCE LINES 363-369 .. code-block:: Python meta_est = RouterConsumerClassifier( estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"), ).set_fit_request(sample_weight="meta_clf_sample_weight") print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'}, 'score': {'sample_weight': None}}, 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': 'clf_sample_weight'}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 370-371 然而,元估计器的 ``fit`` 只需要子估计器的别名,并将它们自己的样本权重作为 `sample_weight` ,因为它不会验证和传递其自身所需的元数据。 .. GENERATED FROM PYTHON SOURCE LINES 371-374 .. code-block:: Python meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in RouterConsumerClassifier. Received sample_weight of length = 100 in ExampleClassifier. .. raw:: html
RouterConsumerClassifier(estimator=ExampleClassifier())
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 375-378 - 仅在子估计器上使用别名: 当我们不希望元估计器使用元数据,但子估计器应该使用时,这很有用。 .. GENERATED FROM PYTHON SOURCE LINES 378-382 .. code-block:: Python meta_est = RouterConsumerClassifier( estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight") ) print_routing(meta_est) .. rst-class:: sphx-glr-script-out .. code-block:: none {'$self_request': {'fit': {'sample_weight': None}, 'score': {'sample_weight': None}}, 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee': 'score', 'caller': 'score'}], 'router': {'fit': {'sample_weight': 'aliased_sample_weight'}, 'predict': {'groups': None}, 'score': {'sample_weight': None}}}} .. GENERATED FROM PYTHON SOURCE LINES 383-384 元估计器不能使用 `aliased_sample_weight` ,因为它期望将其作为 `sample_weight` 传递。即使在元估计器上设置了 `set_fit_request(sample_weight=True)` ,这一点仍然适用。 .. GENERATED FROM PYTHON SOURCE LINES 387-390 简单流水线 --------------- 一个稍微复杂一点的用例是一个类似于 :class:`~pipeline.Pipeline` 的元估计器。这里是一个元估计器,它接受一个转换器和一个分类器。当调用其 `fit` 方法时,它会在运行分类器之前应用转换器的 `fit` 和 `transform` 方法对数据进行转换。在 `predict` 时,它会在使用分类器的 `predict` 方法对新数据进行预测之前应用转换器的 `transform` 方法。 .. GENERATED FROM PYTHON SOURCE LINES 390-446 .. code-block:: Python class SimplePipeline(ClassifierMixin, BaseEstimator): def __init__(self, transformer, classifier): self.transformer = transformer self.classifier = classifier def get_metadata_routing(self): router = ( MetadataRouter(owner=self.__class__.__name__) # 我们为变压器添加了路由。 .add( transformer=self.transformer, method_mapping=MethodMapping() # 元数据的路由方式使其能够追溯 `SimplePipeline` 在其自身方法( `fit` 和 `predict` )中如何内部调用转换器的 `fit` 和 `transform` 方法。 .add(caller="fit", callee="fit") .add(caller="fit", callee="transform") .add(caller="predict", callee="transform"), ) # 我们为分类器添加了路由。 .add( classifier=self.classifier, method_mapping=MethodMapping() .add(caller="fit", callee="fit") .add(caller="predict", callee="predict"), ) ) return router def fit(self, X, y, **fit_params): routed_params = process_routing(self, "fit", **fit_params) self.transformer_ = clone(self.transformer).fit( X, y, **routed_params.transformer.fit ) X_transformed = self.transformer_.transform( X, **routed_params.transformer.transform ) self.classifier_ = clone(self.classifier).fit( X_transformed, y, **routed_params.classifier.fit ) return self def predict(self, X, **predict_params): routed_params = process_routing(self, "predict", **predict_params) X_transformed = self.transformer_.transform( X, **routed_params.transformer.transform ) return self.classifier_.predict( X_transformed, **routed_params.classifier.predict ) .. GENERATED FROM PYTHON SOURCE LINES 447-452 请注意使用 :class:`~utils.metadata_routing.MethodMapping` 来声明子估计器(被调用者)的哪些方法在元估计器(调用者)的哪些方法中使用。如你所见, `SimplePipeline` 在 ``fit`` 方法中使用了转换器的 ``transform`` 和 ``fit`` 方法,并在 ``predict`` 方法中使用了其 ``transform`` 方法,这就是你在管道类的路由结构中看到的实现。 在上述示例中与之前示例的另一个区别是使用了 :func:`~utils.metadata_routing.process_routing` ,该函数处理输入参数,进行必要的验证,并返回我们在之前示例中创建的 `routed_params` 。这减少了开发人员在每个元估计器方法中需要编写的样板代码。强烈建议开发人员使用此函数,除非有充分的理由不使用它。 为了测试上述管道,让我们添加一个示例转换器。 .. GENERATED FROM PYTHON SOURCE LINES 452-467 .. code-block:: Python class ExampleTransformer(TransformerMixin, BaseEstimator): def fit(self, X, y, sample_weight=None): check_metadata(self, sample_weight=sample_weight) return self def transform(self, X, groups=None): check_metadata(self, groups=groups) return X def fit_transform(self, X, y, sample_weight=None, groups=None): return self.fit(X, y, sample_weight).transform(X, groups) .. GENERATED FROM PYTHON SOURCE LINES 468-472 请注意,在上面的示例中,我们实现了调用带有适当元数据的 ``fit`` 和 ``transform`` 的 ``fit_transform`` 。这仅在 ``transform`` 接受元数据时才需要,因为:class:`~base.TransformerMixin` 中的默认 ``fit_transform`` 实现不会将元数据传递给 ``transform`` 。 现在我们可以测试我们的管道,并查看元数据是否被正确传递。 这个例子使用了我们的 `SimplePipeline` 、我们的 `ExampleTransformer` 和我们的 `RouterConsumerClassifier` ,后者使用了我们的 `ExampleClassifier` 。 .. GENERATED FROM PYTHON SOURCE LINES 472-493 .. code-block:: Python pipe = SimplePipeline( transformer=ExampleTransformer() # 我们设置了transformer的fit方法以接收sample_weight .set_fit_request(sample_weight=True) # 我们将变压器的变换设置为接收组 .set_transform_request(groups=True), classifier=RouterConsumerClassifier( estimator=ExampleClassifier() # 我们希望这个子估计器在拟合时接收样本权重 .set_fit_request(sample_weight=True) # 但在预测中没有分组 .set_predict_request(groups=False), ) # 并且我们希望元估计器也能接收样本权重 .set_fit_request(sample_weight=True), ) pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict( X[:3], groups=my_groups ) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in ExampleTransformer. Received groups of length = 100 in ExampleTransformer. Received sample_weight of length = 100 in RouterConsumerClassifier. Received sample_weight of length = 100 in ExampleClassifier. Received groups of length = 100 in ExampleTransformer. groups is None in ExampleClassifier. array([1., 1., 1.]) .. GENERATED FROM PYTHON SOURCE LINES 494-497 弃用/默认值更改 ------------------ 在本节中,我们展示了如何处理路由器也成为消费者的情况,特别是当它消费与其子估计器相同的元数据时,或者消费者开始消费在旧版本中未消费的元数据时。在这种情况下,应发出警告一段时间,以便让用户知道行为已从以前的版本更改。 .. GENERATED FROM PYTHON SOURCE LINES 497-516 .. code-block:: Python class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): def __init__(self, estimator): self.estimator = estimator def fit(self, X, y, **fit_params): routed_params = process_routing(self, "fit", **fit_params) self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) def get_metadata_routing(self): router = MetadataRouter(owner=self.__class__.__name__).add( estimator=self.estimator, method_mapping=MethodMapping().add(caller="fit", callee="fit"), ) return router .. GENERATED FROM PYTHON SOURCE LINES 517-518 正如上文所述,如果 `my_weights` 不应该作为 `sample_weight` 传递给 `MetaRegressor` ,那么这是一个有效的用法: .. GENERATED FROM PYTHON SOURCE LINES 518-524 .. code-block:: Python reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True)) reg.fit(X, y, sample_weight=my_weights) .. GENERATED FROM PYTHON SOURCE LINES 525-526 现在想象一下我们进一步开发了 ``MetaRegressor`` ,它现在也*使用* ``sample_weight`` : .. GENERATED FROM PYTHON SOURCE LINES 526-556 .. code-block:: Python class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): # 显示警告,提醒用户显式设置值,使用 # `.set_{method}_request(sample_weight={boolean})` __metadata_request__fit = {"sample_weight": metadata_routing.WARN} def __init__(self, estimator): self.estimator = estimator def fit(self, X, y, sample_weight=None, **fit_params): routed_params = process_routing( self, "fit", sample_weight=sample_weight, **fit_params ) check_metadata(self, sample_weight=sample_weight) self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) def get_metadata_routing(self): router = ( MetadataRouter(owner=self.__class__.__name__) .add_self_request(self) .add( estimator=self.estimator, method_mapping=MethodMapping().add(caller="fit", callee="fit"), ) ) return router .. GENERATED FROM PYTHON SOURCE LINES 557-558 上述实现与 ``MetaRegressor`` 几乎相同,并且由于在 ``__metadata_request__fit`` 中定义的默认请求值,在拟合时会引发警告。 .. GENERATED FROM PYTHON SOURCE LINES 558-568 .. code-block:: Python with warnings.catch_warnings(record=True) as record: WeightedMetaRegressor( estimator=LinearRegression().set_fit_request(sample_weight=False) ).fit(X, y, sample_weight=my_weights) for w in record: print(w.message) .. rst-class:: sphx-glr-script-out .. code-block:: none Received sample_weight of length = 100 in WeightedMetaRegressor. Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata. .. GENERATED FROM PYTHON SOURCE LINES 569-570 当估计器消耗以前未消耗的元数据时,可以使用以下模式来警告用户。 .. GENERATED FROM PYTHON SOURCE LINES 570-589 .. code-block:: Python class ExampleRegressor(RegressorMixin, BaseEstimator): __metadata_request__fit = {"sample_weight": metadata_routing.WARN} def fit(self, X, y, sample_weight=None): check_metadata(self, sample_weight=sample_weight) return self def predict(self, X): return np.zeros(shape=(len(X))) with warnings.catch_warnings(record=True) as record: MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights) for w in record: print(w.message) .. rst-class:: sphx-glr-script-out .. code-block:: none sample_weight is None in ExampleRegressor. Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata. .. GENERATED FROM PYTHON SOURCE LINES 590-591 最后,我们禁用元数据路由的配置标志: .. GENERATED FROM PYTHON SOURCE LINES 591-595 .. code-block:: Python set_config(enable_metadata_routing=False) .. GENERATED FROM PYTHON SOURCE LINES 596-604 第三方开发和scikit-learn依赖 --------------------------------------------------- 如上所述,信息通过 :class:`~utils.metadata_routing.MetadataRequest` 和 :class:`~utils.metadata_routing.MetadataRouter` 在类之间传递。强烈不建议这样做,但如果您严格希望拥有一个兼容 scikit-learn 的估计器,而不依赖于 scikit-learn 包,则可以使用与元数据路由相关的工具。如果满足以下所有条件,您完全不需要修改代码: - 你的估计器继承自 :class:`~base.BaseEstimator` - 估计器方法(例如 ``fit`` )所需的参数在方法的签名中明确定义,而不是使用 ``*args`` 或 ``*kwargs`` 。 - 你的估计器不会将任何元数据传递给底层对象,即它不是一个*路由器*。 .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 0.022 seconds) .. _sphx_glr_download_auto_examples_miscellaneous_plot_metadata_routing.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/miscellaneous/plot_metadata_routing.ipynb :alt: Launch binder :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_metadata_routing.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_metadata_routing.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_metadata_routing.zip ` .. include:: plot_metadata_routing.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_