.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/compose/plot_transformed_target.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_compose_plot_transformed_target.py>`
        to download the full example code, or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_transformed_target.py:


======================================================
Effect of transforming the targets in regression model
======================================================

In this example, we give an overview of
:class:`~sklearn.compose.TransformedTargetRegressor`. We use two examples
to illustrate the benefit of transforming the targets before learning a linear
regression model. The first example uses synthetic data, while the second
example is based on the Ames housing data set.

.. GENERATED FROM PYTHON SOURCE LINES 9-15

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    print(__doc__)


.. GENERATED FROM PYTHON SOURCE LINES 16-17

Synthetic example

.. GENERATED FROM PYTHON SOURCE LINES 17-34

.. code-block:: Python

    ###################
    #
    # A synthetic random regression dataset is generated. The targets ``y`` are
    # modified by:
    #
    # 1. translating all targets so that all entries are non-negative (by
    #    adding the absolute value of the lowest ``y``), and
    # 2. applying an exponential function to obtain non-linear targets that
    #    cannot be fitted using a simple linear model.
    #
    # Therefore, a logarithmic function (`np.log1p`) and an exponential function
    # (`np.expm1`) will be used to transform the targets before training a linear
    # regression model and using it for prediction.
    import numpy as np

    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=10_000, noise=100, random_state=0)
    y = np.expm1((y + abs(y.min())) / 200)
    y_trans = np.log1p(y)


.. GENERATED FROM PYTHON SOURCE LINES 35-36

Below we plot the probability density functions of the target before and after
applying the logarithmic function.

.. GENERATED FROM PYTHON SOURCE LINES 36-58

.. code-block:: Python

    import matplotlib.pyplot as plt

    from sklearn.model_selection import train_test_split

    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, density=True)
    ax0.set_xlim([0, 2000])
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.set_title("Target distribution")

    ax1.hist(y_trans, bins=100, density=True)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.set_title("Transformed target distribution")

    f.suptitle("Synthetic data", y=1.05)
    plt.tight_layout()

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)



.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :alt: Synthetic data, Target distribution, Transformed target distribution
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 59-60

First, a linear model is fitted on the original targets. Because of the
non-linearity, the trained model will not predict accurately. Next, a
logarithmic function is used to linearize the targets, which allows better
predictions even with a similar linear model, as reported by the median
absolute error (MedAE).

.. GENERATED FROM PYTHON SOURCE LINES 60-71

.. code-block:: Python

    from sklearn.metrics import median_absolute_error, r2_score


    def compute_score(y_true, y_pred):
        return {
            "R2": f"{r2_score(y_true, y_pred):.3f}",
            "MedAE": f"{median_absolute_error(y_true, y_pred):.3f}",
        }


.. GENERATED FROM PYTHON SOURCE LINES 72-112

.. code-block:: Python

    from sklearn.compose import TransformedTargetRegressor
    from sklearn.linear_model import RidgeCV
    from sklearn.metrics import PredictionErrorDisplay

    f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)

    ridge_cv = RidgeCV().fit(X_train, y_train)
    y_pred_ridge = ridge_cv.predict(X_test)

    ridge_cv_with_trans_target = TransformedTargetRegressor(
        regressor=RidgeCV(), func=np.log1p, inverse_func=np.expm1
    ).fit(X_train, y_train)
    y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test)

    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="actual_vs_predicted",
        ax=ax0,
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="actual_vs_predicted",
        ax=ax1,
        scatter_kwargs={"alpha": 0.5},
    )

    # Add the scores to the legend of each axis
    for ax, y_pred in zip([ax0, ax1], [y_pred_ridge, y_pred_ridge_with_trans_target]):
        for name, score in compute_score(y_test, y_pred).items():
            ax.plot([], [], " ", label=f"{name}={score}")
        ax.legend(loc="upper left")

    ax0.set_title("Ridge regression \n without target transformation")
    ax1.set_title("Ridge regression \n with target transformation")
    f.suptitle("Synthetic data", y=1.05)
    plt.tight_layout()



.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :alt: Synthetic data, Ridge regression   without target transformation, Ridge regression   with target transformation
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 113-114

Real-world data set

.. GENERATED FROM PYTHON SOURCE LINES 116-117

In a similar manner, the Ames housing data set is used to show the impact of
transforming the targets before learning a model. In this example, the target
to be predicted is the selling price of each house.

.. GENERATED FROM PYTHON SOURCE LINES 118-132

.. code-block:: Python

    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import quantile_transform

    ames = fetch_openml(name="house_prices", as_frame=True)
    # Keep only numeric columns
    X = ames.data.select_dtypes(np.number)
    # Remove columns with NaN or Inf values
    X = X.drop(columns=["LotFrontage", "GarageYrBlt", "MasVnrArea"])
    # Express the price in thousands of US dollars
    y = ames.target / 1000
    y_trans = quantile_transform(
        y.to_frame(), n_quantiles=900, output_distribution="normal", copy=True
    ).squeeze()


.. GENERATED FROM PYTHON SOURCE LINES 133-134

A :class:`~sklearn.preprocessing.QuantileTransformer` is used to normalize
the target distribution before applying a
:class:`~sklearn.linear_model.RidgeCV` model.

.. GENERATED FROM PYTHON SOURCE LINES 134-150

.. code-block:: Python

    f, (ax0, ax1) = plt.subplots(1, 2)

    ax0.hist(y, bins=100, density=True)
    ax0.set_ylabel("Probability")
    ax0.set_xlabel("Target")
    ax0.set_title("Target distribution")

    ax1.hist(y_trans, bins=100, density=True)
    ax1.set_ylabel("Probability")
    ax1.set_xlabel("Target")
    ax1.set_title("Transformed target distribution")

    f.suptitle("Ames housing data: selling price", y=1.05)
    plt.tight_layout()



.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_003.png
   :alt: Ames housing data: selling price, Target distribution, Transformed target distribution
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 151-153

.. code-block:: Python

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)


.. GENERATED FROM PYTHON SOURCE LINES 154-155

The effect of the transformer is weaker than on the synthetic data. However,
the transformation increases :math:`R^2` and substantially decreases the
MedAE. Without target transformation, the residual plot (predicted target -
true target vs predicted target) takes on a curved, "reverse smile" shape
because the residual values vary with the value of the predicted target. With
target transformation, the shape is more linear, indicating a better model fit.

.. GENERATED FROM PYTHON SOURCE LINES 155-215

.. code-block:: Python

    from sklearn.preprocessing import QuantileTransformer

    f, (ax0, ax1) = plt.subplots(2, 2, sharey="row", figsize=(6.5, 8))

    ridge_cv = RidgeCV().fit(X_train, y_train)
    y_pred_ridge = ridge_cv.predict(X_test)

    ridge_cv_with_trans_target = TransformedTargetRegressor(
        regressor=RidgeCV(),
        transformer=QuantileTransformer(n_quantiles=900, output_distribution="normal"),
    ).fit(X_train, y_train)
    y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test)

    # Plot the actual vs predicted values
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="actual_vs_predicted",
        ax=ax0[0],
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="actual_vs_predicted",
        ax=ax0[1],
        scatter_kwargs={"alpha": 0.5},
    )

    # Add the scores to the legend of each axis
    for ax, y_pred in zip([ax0[0], ax0[1]], [y_pred_ridge, y_pred_ridge_with_trans_target]):
        for name, score in compute_score(y_test, y_pred).items():
            ax.plot([], [], " ", label=f"{name}={score}")
        ax.legend(loc="upper left")

    ax0[0].set_title("Ridge regression \n without target transformation")
    ax0[1].set_title("Ridge regression \n with target transformation")

    # Plot the residuals vs the predicted values
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge,
        kind="residual_vs_predicted",
        ax=ax1[0],
        scatter_kwargs={"alpha": 0.5},
    )
    PredictionErrorDisplay.from_predictions(
        y_test,
        y_pred_ridge_with_trans_target,
        kind="residual_vs_predicted",
        ax=ax1[1],
        scatter_kwargs={"alpha": 0.5},
    )
    ax1[0].set_title("Ridge regression \n without target transformation")
    ax1[1].set_title("Ridge regression \n with target transformation")

    f.suptitle("Ames housing data: selling price", y=1.05)
    plt.tight_layout()
    plt.show()



.. image-sg:: /auto_examples/compose/images/sphx_glr_plot_transformed_target_004.png
   :alt: Ames housing data: selling price, Ridge regression   without target transformation, Ridge regression   with target transformation, Ridge regression   without target transformation, Ridge regression   with target transformation
   :srcset: /auto_examples/compose/images/sphx_glr_plot_transformed_target_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.722 seconds)


.. _sphx_glr_download_auto_examples_compose_plot_transformed_target.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/compose/plot_transformed_target.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_transformed_target.ipynb <plot_transformed_target.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_transformed_target.py <plot_transformed_target.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_transformed_target.zip <plot_transformed_target.zip>`


.. include:: plot_transformed_target.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_