Note
Go to the end to download the full example code. or to run this example in your browser via Binder
元数据路由#
本文档展示了如何在scikit-learn中使用:ref:元数据路由机制<metadata_routing>
将元数据路由到使用它们的估计器、评分器和交叉验证分割器。
为了更好地理解以下文档,我们需要介绍两个概念:路由器和消费者。路由器是一个将给定数据和元数据转发到其他对象的对象。在大多数情况下,路由器是一个:term:元估计器
,即一个将另一个估计器作为参数的估计器。像:func:sklearn.model_selection.cross_validate
这样的函数,它将估计器作为参数并转发数据和元数据,也是一个路由器。
另一方面,消费者是一个接受并使用给定元数据的对象。例如,在其:term:fit
方法中考虑 sample_weight
的估计器是 sample_weight
的消费者。
一个对象可以同时是路由器和消费者。例如,一个元估计器可能在某些计算中考虑 sample_weight
,但它也可能将其路由到底层估计器。
首先是一些导入和一些用于其余脚本的随机数据。
import warnings
from pprint import pprint
import numpy as np
from sklearn import set_config
from sklearn.base import (
BaseEstimator,
ClassifierMixin,
MetaEstimatorMixin,
RegressorMixin,
TransformerMixin,
clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
get_routing_for_object,
process_routing,
)
from sklearn.utils.validation import check_is_fitted
n_samples, n_features = 100, 4
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
y = rng.randint(0, 2, size=n_samples)
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)
元数据路由仅在明确启用时可用:
set_config(enable_metadata_routing=True)
这个实用函数是一个虚拟函数,用于检查是否传递了元数据:
def check_metadata(obj, **kwargs):
for key, value in kwargs.items():
if value is not None:
print(
f"Received {key} of length = {len(value)} in {obj.__class__.__name__}."
)
else:
print(f"{key} is None in {obj.__class__.__name__}.")
一个用于美观地打印对象路由信息的实用函数:
def print_routing(obj):
pprint(obj.get_metadata_routing()._serialize())
消耗估计器#
在这里我们演示了一个估计器如何暴露所需的 API 以支持作为消费者的元数据路由。想象一个简单的分类器在其 fit
方法中接受 sample_weight
作为元数据,并在其 predict
方法中接受 groups
:
class ExampleClassifier(ClassifierMixin, BaseEstimator):
def fit(self, X, y, sample_weight=None):
check_metadata(self, sample_weight=sample_weight)
# 所有分类器在拟合后都需要暴露一个 classes_ 属性。
self.classes_ = np.array([0, 1])
return self
def predict(self, X, groups=None):
check_metadata(self, groups=groups)
# 返回一个常数值1,不是一个非常聪明的分类器!
return np.ones(len(X))
上述估计器现在具备了处理元数据所需的一切。这是通过一些在 BaseEstimator
中完成的魔法实现的。现在,上述类公开了三个方法: set_fit_request
、 set_predict_request
和 get_metadata_routing
。此外,还有一个用于 sample_weight
的 set_score_request
,因为 ClassifierMixin
实现了一个接受 sample_weight
的 score
方法。同样的情况也适用于继承自 RegressorMixin
的回归器。
默认情况下,不会请求任何元数据,我们可以看到如下:
print_routing(ExampleClassifier())
{'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}
上述输出意味着 ExampleClassifier
并不需要 sample_weight
和 groups
,如果路由器接收到这些元数据,它应该抛出一个错误,因为用户没有明确设置它们是否是必需的。同样的情况也适用于从 ClassifierMixin
继承的 score
方法中的 sample_weight
。为了明确设置这些元数据的请求值,我们可以使用以下方法:
est = (
ExampleClassifier()
.set_fit_request(sample_weight=False)
.set_predict_request(groups=True)
.set_score_request(sample_weight=False)
)
print_routing(est)
{'fit': {'sample_weight': False},
'predict': {'groups': True},
'score': {'sample_weight': False}}
请注意,只要上述估计器未在元估计器中使用,用户就不需要设置任何元数据请求,并且设置的值将被忽略,因为消费者不会验证或路由给定的元数据。上述估计器的简单使用将按预期工作。
est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.
array([1., 1., 1.])
路由元估计器#
现在,我们展示如何设计一个元估计器作为路由器。作为一个简化的例子,这里有一个元估计器,它除了路由元数据之外没有做太多其他事情。
class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
def __init__(self, estimator):
self.estimator = estimator
def get_metadata_routing(self):
# 此方法定义了该元估计器的路由。
# 为此,会创建一个 `MetadataRouter` 实例,并将路由添加到其中。更多解释如下。
router = MetadataRouter(owner=self.__class__.__name__).add(
estimator=self.estimator,
method_mapping=MethodMapping()
.add(caller="fit", callee="fit")
.add(caller="predict", callee="predict")
.add(caller="score", callee="score"),
)
return router
def fit(self, X, y, **fit_params):
# `get_routing_for_object` 返回一个由上述 `get_metadata_routing` 方法构建的 `MetadataRouter` 的副本,该方法在内部被调用。
request_router = get_routing_for_object(self)
# 元估计器负责验证给定的元数据。
# `method` 指的是父类的方法,即本例中的 `fit` 。
request_router.validate_metadata(params=fit_params, method="fit")
# `MetadataRouter.route_params` 根据 MetadataRouter 定义的路由信息,将给定的元数据映射到底层估计器所需的元数据。类型为 `Bunch` 的输出包含每个消费对象的键,这些键包含其消费方法的键,然后这些键包含应路由到它们的元数据的键。
routed_params = request_router.route_params(params=fit_params, caller="fit")
# 子估计器被拟合,其类别被归属于元估计器。
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
self.classes_ = self.estimator_.classes_
return self
def predict(self, X, **predict_params):
check_is_fitted(self)
# 正如在 `fit` 中一样,我们获得了对象的 MetadataRouter 的副本,
request_router = get_routing_for_object(self)
# 然后我们验证给定的元数据,
request_router.validate_metadata(params=predict_params, method="predict")
# 然后准备底层 `predict` 方法的输入。
routed_params = request_router.route_params(
params=predict_params, caller="predict"
)
return self.estimator_.predict(X, **routed_params.estimator.predict)
让我们分解上述代码的不同部分。
首先,get_routing_for_object
方法接收我们的元估计器( self
),并返回一个 MetadataRouter
对象,或者如果对象是消费者,则返回一个 MetadataRequest
对象,这取决于估计器的 get_metadata_routing
方法的输出。
然后在每个方法中,我们使用 route_params
方法构建一个形式为 {"object_name": {"method_name": {"metadata": value}}}
的字典,以传递给底层估计器的方法。 object_name
(在上述 routed_params.estimator.fit
示例中为 estimator
)与在 get_metadata_routing
中添加的相同。 validate_metadata
确保所有给定的元数据都被请求,以避免静默错误。
接下来,我们将说明不同的行为,特别是引发的错误类型。
meta_est = MetaClassifier(
estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
meta_est.fit(X, y, sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
请注意,上述示例通过 ExampleClassifier
调用了我们的实用函数 check_metadata()
。它检查 sample_weight
是否被正确传递。如果没有正确传递,如以下示例所示,它将打印 sample_weight
为 None
:
meta_est.fit(X, y)
sample_weight is None in ExampleClassifier.
如果我们传递了未知的元数据,将会引发错误:
try:
meta_est.fit(X, y, test=my_weights)
except TypeError as e:
print(e)
MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object.
如果我们传递一个未明确请求的元数据:
try:
meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
print(e)
Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore.
此外,如果我们明确设置为不请求,但它已被提供:
meta_est = MetaClassifier(
estimator=ExampleClassifier()
.set_fit_request(sample_weight=True)
.set_predict_request(groups=False)
)
try:
meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
print(e)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object.
另一个需要介绍的概念是 别名元数据 。这是指估计器请求的元数据使用了不同于默认变量名的变量名。例如,在一个包含两个估计器的管道中,一个可以请求 sample_weight1
,另一个可以请求 sample_weight2
。需要注意的是,这并不会改变估计器的预期,它只是告诉元估计器如何将提供的元数据映射到所需的内容。以下是一个示例,我们将 aliased_sample_weight
传递给元估计器,但元估计器理解 aliased_sample_weight
是 sample_weight
的别名,并将其作为 sample_weight
传递给底层估计器:
meta_est = MetaClassifier(
estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
传递 sample_weight
会失败,因为它是通过别名请求的,而没有请求名为 sample_weight
的参数:
try:
meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
print(e)
MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object.
这将引导我们了解 get_metadata_routing
。在 scikit-learn 中,路由的工作方式是消费者请求他们需要的内容,而路由器则传递这些内容。此外,路由器会公开其自身的需求,以便可以在另一个路由器内部使用,例如网格搜索对象中的管道。 get_metadata_routing
的输出是 MetadataRouter
的字典表示形式,其中包括所有嵌套对象请求的元数据的完整树及其相应的方法路由,即子估计器的哪个方法在元估计器的哪个方法中使用。
print_routing(meta_est)
{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
如你所见,方法 fit
唯一请求的元数据是 "sample_weight"
,其别名为 "aliased_sample_weight"
。 ~utils.metadata_routing.MetadataRouter
类使我们能够轻松创建路由对象,从而生成 get_metadata_routing
所需的输出。
为了理解别名在元估计器中的工作原理,想象我们的元估计器在另一个元估计器内部:
meta_meta_est = MetaClassifier(estimator=meta_est).fit(
X, y, aliased_sample_weight=my_weights
)
Received sample_weight of length = 100 in ExampleClassifier.
在上面的例子中, meta_meta_est
的 fit
方法将这样调用其子估计器的 fit
方法:
用户将 my_weights
作为 aliased_sample_weight
输入到 meta_meta_est
中:
meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
…
第一个子估计器 ( meta_est
) 需要 aliased_sample_weight
self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
…
第二个子估计器( est
)需要 sample_weight
self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
…
消耗和路由元估计器#
对于一个稍微复杂一点的例子,考虑一个元估计器,它像之前一样将元数据路由到底层估计器,但它也在自己的方法中使用一些元数据。这个元估计器同时是一个消费者和路由器。实现这样的元估计器与之前的实现非常相似,但有一些调整。
class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
def __init__(self, estimator):
self.estimator = estimator
def get_metadata_routing(self):
router = (
MetadataRouter(owner=self.__class__.__name__)
# 定义元数据路由请求值以供元估计器使用
.add_self_request(self)
# 定义元数据路由请求值以供子估算器使用
.add(
estimator=self.estimator,
method_mapping=MethodMapping()
.add(caller="fit", callee="fit")
.add(caller="predict", callee="predict")
.add(caller="score", callee="score"),
)
)
return router
# 由于 `sample_weight` 在此处被使用和消耗,因此应在方法的签名中将其定义为显式参数。所有其他仅被传递的元数据将作为 `**fit_params` 传递:
def fit(self, X, y, sample_weight, **fit_params):
if self.estimator is None:
raise ValueError("estimator cannot be None!")
check_metadata(self, sample_weight=sample_weight)
# 我们将 `sample_weight` 添加到 `fit_params` 字典中。
if sample_weight is not None:
fit_params["sample_weight"] = sample_weight
request_router = get_routing_for_object(self)
request_router.validate_metadata(params=fit_params, method="fit")
routed_params = request_router.route_params(params=fit_params, caller="fit")
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
self.classes_ = self.estimator_.classes_
return self
def predict(self, X, **predict_params):
check_is_fitted(self)
# 正如在 `fit` 中一样,我们获得了对象的 MetadataRouter 的副本,
request_router = get_routing_for_object(self)
# 我们验证了给定的元数据,
request_router.validate_metadata(params=predict_params, method="predict")
# 然后准备底层 ``predict`` 方法的输入。
routed_params = request_router.route_params(
params=predict_params, caller="predict"
)
return self.estimator_.predict(X, **routed_params.estimator.predict)
上述元估计器与我们之前的元估计器的主要区别在于,它在 fit
中显式接受 sample_weight
并将其包含在 fit_params
中。由于 sample_weight
是一个显式参数,我们可以确保该方法中存在 set_fit_request(sample_weight=...)
。该元估计器既是 sample_weight
的使用者,也是路由器。
在 get_metadata_routing
中,我们使用 add_self_request
将 self
添加到路由中,以表明此估计器正在使用 sample_weight
并且也是一个路由器;这还会在路由信息中添加一个 $self_request
键,如下所示。现在让我们来看一些例子:
未请求元数据
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
子估计器请求的
sample_weight
meta_est = RouterConsumerClassifier(
estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': True},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
sample_weight
被元估计器请求
meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
sample_weight=True
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': True},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': None},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
请注意上面请求的元数据表示之间的差异。
我们还可以为元估计器和子估计器的拟合方法传递不同的值来设置元数据的别名:
meta_est = RouterConsumerClassifier(
estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'clf_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
然而,元估计器的 fit
只需要子估计器的别名,并将它们自己的样本权重作为 sample_weight
,因为它不会验证和传递其自身所需的元数据。
meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
仅在子估计器上使用别名:
当我们不希望元估计器使用元数据,但子估计器应该使用时,这很有用。
meta_est = RouterConsumerClassifier(
estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
'score': {'sample_weight': None}},
'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
{'callee': 'predict', 'caller': 'predict'},
{'callee': 'score', 'caller': 'score'}],
'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
'predict': {'groups': None},
'score': {'sample_weight': None}}}}
元估计器不能使用 aliased_sample_weight
,因为它期望将其作为 sample_weight
传递。即使在元估计器上设置了 set_fit_request(sample_weight=True)
,这一点仍然适用。
简单流水线#
一个稍微复杂一点的用例是一个类似于 Pipeline
的元估计器。这里是一个元估计器,它接受一个转换器和一个分类器。当调用其 fit
方法时,它会在运行分类器之前应用转换器的 fit
和 transform
方法对数据进行转换。在 predict
时,它会在使用分类器的 predict
方法对新数据进行预测之前应用转换器的 transform
方法。
class SimplePipeline(ClassifierMixin, BaseEstimator):
def __init__(self, transformer, classifier):
self.transformer = transformer
self.classifier = classifier
def get_metadata_routing(self):
router = (
MetadataRouter(owner=self.__class__.__name__)
# 我们为变压器添加了路由。
.add(
transformer=self.transformer,
method_mapping=MethodMapping()
# 元数据的路由方式使其能够追溯 `SimplePipeline` 在其自身方法( `fit` 和 `predict` )中如何内部调用转换器的 `fit` 和 `transform` 方法。
.add(caller="fit", callee="fit")
.add(caller="fit", callee="transform")
.add(caller="predict", callee="transform"),
)
# 我们为分类器添加了路由。
.add(
classifier=self.classifier,
method_mapping=MethodMapping()
.add(caller="fit", callee="fit")
.add(caller="predict", callee="predict"),
)
)
return router
def fit(self, X, y, **fit_params):
routed_params = process_routing(self, "fit", **fit_params)
self.transformer_ = clone(self.transformer).fit(
X, y, **routed_params.transformer.fit
)
X_transformed = self.transformer_.transform(
X, **routed_params.transformer.transform
)
self.classifier_ = clone(self.classifier).fit(
X_transformed, y, **routed_params.classifier.fit
)
return self
def predict(self, X, **predict_params):
routed_params = process_routing(self, "predict", **predict_params)
X_transformed = self.transformer_.transform(
X, **routed_params.transformer.transform
)
return self.classifier_.predict(
X_transformed, **routed_params.classifier.predict
)
请注意使用 MethodMapping
来声明子估计器(被调用者)的哪些方法在元估计器(调用者)的哪些方法中使用。如你所见, SimplePipeline
在 fit
方法中使用了转换器的 transform
和 fit
方法,并在 predict
方法中使用了其 transform
方法,这就是你在管道类的路由结构中看到的实现。
在上述示例中与之前示例的另一个区别是使用了 process_routing
,该函数处理输入参数,进行必要的验证,并返回我们在之前示例中创建的 routed_params
。这减少了开发人员在每个元估计器方法中需要编写的样板代码。强烈建议开发人员使用此函数,除非有充分的理由不使用它。
为了测试上述管道,让我们添加一个示例转换器。
class ExampleTransformer(TransformerMixin, BaseEstimator):
def fit(self, X, y, sample_weight=None):
check_metadata(self, sample_weight=sample_weight)
return self
def transform(self, X, groups=None):
check_metadata(self, groups=groups)
return X
def fit_transform(self, X, y, sample_weight=None, groups=None):
return self.fit(X, y, sample_weight).transform(X, groups)
请注意,在上面的示例中,我们实现了调用带有适当元数据的 fit
和 transform
的 fit_transform
。这仅在 transform
接受元数据时才需要,因为:class:~base.TransformerMixin
中的默认 fit_transform
实现不会将元数据传递给 transform
。
现在我们可以测试我们的管道,并查看元数据是否被正确传递。
这个例子使用了我们的 SimplePipeline
、我们的 ExampleTransformer
和我们的 RouterConsumerClassifier
,后者使用了我们的 ExampleClassifier
。
pipe = SimplePipeline(
transformer=ExampleTransformer()
# 我们设置了transformer的fit方法以接收sample_weight
.set_fit_request(sample_weight=True)
# 我们将变压器的变换设置为接收组
.set_transform_request(groups=True),
classifier=RouterConsumerClassifier(
estimator=ExampleClassifier()
# 我们希望这个子估计器在拟合时接收样本权重
.set_fit_request(sample_weight=True)
# 但在预测中没有分组
.set_predict_request(groups=False),
)
# 并且我们希望元估计器也能接收样本权重
.set_fit_request(sample_weight=True),
)
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
X[:3], groups=my_groups
)
Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.
array([1., 1., 1.])
弃用/默认值更改#
在本节中,我们展示了如何处理路由器也成为消费者的情况,特别是当它消费与其子估计器相同的元数据时,或者消费者开始消费在旧版本中未消费的元数据时。在这种情况下,应发出警告一段时间,以便让用户知道行为已从以前的版本更改。
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
def __init__(self, estimator):
self.estimator = estimator
def fit(self, X, y, **fit_params):
routed_params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
def get_metadata_routing(self):
router = MetadataRouter(owner=self.__class__.__name__).add(
estimator=self.estimator,
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
)
return router
正如上文所述,如果 my_weights
不应该作为 sample_weight
传递给 MetaRegressor
,那么这是一个有效的用法:
reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
reg.fit(X, y, sample_weight=my_weights)
现在想象一下我们进一步开发了 MetaRegressor
,它现在也*使用* sample_weight
:
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
# 显示警告,提醒用户显式设置值,使用
# `.set_{method}_request(sample_weight={boolean})`
__metadata_request__fit = {"sample_weight": metadata_routing.WARN}
def __init__(self, estimator):
self.estimator = estimator
def fit(self, X, y, sample_weight=None, **fit_params):
routed_params = process_routing(
self, "fit", sample_weight=sample_weight, **fit_params
)
check_metadata(self, sample_weight=sample_weight)
self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
def get_metadata_routing(self):
router = (
MetadataRouter(owner=self.__class__.__name__)
.add_self_request(self)
.add(
estimator=self.estimator,
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
)
)
return router
上述实现与 MetaRegressor
几乎相同,并且由于在 __metadata_request__fit
中定义的默认请求值,在拟合时会引发警告。
with warnings.catch_warnings(record=True) as record:
WeightedMetaRegressor(
estimator=LinearRegression().set_fit_request(sample_weight=False)
).fit(X, y, sample_weight=my_weights)
for w in record:
print(w.message)
Received sample_weight of length = 100 in WeightedMetaRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.
当估计器消耗以前未消耗的元数据时,可以使用以下模式来警告用户。
class ExampleRegressor(RegressorMixin, BaseEstimator):
__metadata_request__fit = {"sample_weight": metadata_routing.WARN}
def fit(self, X, y, sample_weight=None):
check_metadata(self, sample_weight=sample_weight)
return self
def predict(self, X):
return np.zeros(shape=(len(X)))
with warnings.catch_warnings(record=True) as record:
MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
print(w.message)
sample_weight is None in ExampleRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.
最后,我们禁用元数据路由的配置标志:
set_config(enable_metadata_routing=False)
第三方开发和scikit-learn依赖#
如上所述,信息通过 MetadataRequest
和 MetadataRouter
在类之间传递。强烈不建议这样做,但如果您严格希望拥有一个兼容 scikit-learn 的估计器,而不依赖于 scikit-learn 包,则可以使用与元数据路由相关的工具。如果满足以下所有条件,您完全不需要修改代码:
你的估计器继承自
BaseEstimator
估计器方法(例如
fit
)所需的参数在方法的签名中明确定义,而不是使用*args
或*kwargs
。你的估计器不会将任何元数据传递给底层对象,即它不是一个*路由器*。
Total running time of the script: (0 minutes 0.022 seconds)
Related examples