.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/svm/plot_svm_scale_c.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_svm_plot_svm_scale_c.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_svm_plot_svm_scale_c.py:


==============================================
Scaling the regularization parameter for SVCs
==============================================

The following example illustrates the effect of scaling the regularization
parameter when using :ref:`svm` for :ref:`classification <svm_classification>`.
For SVC classification, we are interested in minimizing the risk given by the
following expression:

.. math::

    C \sum_{i=1}^{n} \mathcal{L} (f(x_i), y_i) + \Omega (w)

where

- :math:`C` sets the amount of regularization
- :math:`\mathcal{L}` is a `loss` function of the samples and the model parameters
- :math:`\Omega` is a `penalty` function of the model parameters

If we consider the loss function to be the individual error per sample, then
the data-fit term, that is, the sum of the errors over all samples, increases
as we add more samples. The penalty term, however, does not increase.

When using, for example, :ref:`cross-validation <cross_validation>` to set the
regularization parameter `C`, the number of samples differs between the main
problem and the smaller problems within the folds of the cross-validation.

Since the loss function depends on the number of samples, the latter
influences the selected value of `C`. The question that arises is "How do we
optimally adjust C to account for the different number of training samples?"

.. GENERATED FROM PYTHON SOURCE LINES 25-29

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause


.. GENERATED FROM PYTHON SOURCE LINES 30-34

Data generation
---------------

In this example we investigate the effect of reparametrizing the
regularization parameter `C` to account for the number of samples when using
either an L1 or an L2 penalty. For this purpose we create a synthetic dataset
with a large number of features, of which only a few are informative. We
therefore expect the regularization to shrink the coefficients towards zero
(L2 penalty) or to exactly zero (L1 penalty).

.. GENERATED FROM PYTHON SOURCE LINES 34-42

.. code-block:: Python


    from sklearn.datasets import make_classification

    n_samples, n_features = 100, 300
    X, y = make_classification(
        n_samples=n_samples, n_features=n_features, n_informative=5, random_state=1
    )


.. GENERATED FROM PYTHON SOURCE LINES 43-48

L1-penalty case
---------------

In the L1 case, theory says that under strong regularization the estimator
cannot predict as well as a model that knows the true distribution (even in
the limit where the sample size grows to infinity), because it may set the
weights of some otherwise predictive features to zero, which induces a bias.
Theory also says, however, that it is possible to find the right set of
non-zero parameters, as well as their signs, by tuning `C`.

We define a linear SVC with the L1 penalty.

.. GENERATED FROM PYTHON SOURCE LINES 48-53

.. code-block:: Python


    from sklearn.svm import LinearSVC

    model_l1 = LinearSVC(penalty="l1", loss="squared_hinge", dual=False, tol=1e-3)


.. GENERATED FROM PYTHON SOURCE LINES 54-55

We compute the mean test score for different values of `C` via
cross-validation.

.. GENERATED FROM PYTHON SOURCE LINES 55-86

.. code-block:: Python


    import numpy as np
    import pandas as pd

    from sklearn.model_selection import ShuffleSplit, validation_curve

    Cs = np.logspace(-2.3, -1.3, 10)
    train_sizes = np.linspace(0.3, 0.7, 3)
    labels = [f"fraction: {train_size}" for train_size in train_sizes]
    shuffle_params = {
        "test_size": 0.3,
        "n_splits": 150,
        "random_state": 1,
    }

    results = {"C": Cs}
    for label, train_size in zip(labels, train_sizes):
        cv = ShuffleSplit(train_size=train_size, **shuffle_params)
        train_scores, test_scores = validation_curve(
            model_l1,
            X,
            y,
            param_name="C",
            param_range=Cs,
            cv=cv,
            n_jobs=2,
        )
        results[label] = test_scores.mean(axis=1)
    results = pd.DataFrame(results)


.. GENERATED FROM PYTHON SOURCE LINES 87-114

.. code-block:: Python


    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(12, 6))

    # plot results without scaling C
    results.plot(x="C", ax=axes[0], logx=True)
    axes[0].set_ylabel("CV score")
    axes[0].set_title("No scaling")

    for label in labels:
        best_C = results.loc[results[label].idxmax(), "C"]
        axes[0].axvline(x=best_C, linestyle="--", color="grey", alpha=0.7)

    # plot results by scaling C
    for train_size_idx, label in enumerate(labels):
        train_size = train_sizes[train_size_idx]
        results_scaled = results[[label]].assign(
            C_scaled=Cs * float(n_samples * np.sqrt(train_size))
        )
        results_scaled.plot(x="C_scaled", ax=axes[1], logx=True, label=label)
        best_C_scaled = results_scaled["C_scaled"].loc[results[label].idxmax()]
        axes[1].axvline(x=best_C_scaled, linestyle="--", color="grey", alpha=0.7)

    axes[1].set_title("Scaling C by sqrt(1 / n_samples)")

    _ = fig.suptitle("Effect of scaling C with L1 penalty")



.. image-sg:: /auto_examples/svm/images/sphx_glr_plot_svm_scale_c_001.png
   :alt: Effect of scaling C with L1 penalty, No scaling, Scaling C by sqrt(1 / n_samples)
   :srcset: /auto_examples/svm/images/sphx_glr_plot_svm_scale_c_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 115-124

In the region of small `C` (strong regularization), all the coefficients
learned by the model are zero, leading to severe underfitting. Indeed, the
accuracy in this region is at chance level.

Using the default scale results in a fairly stable optimal value of `C`,
whereas the transition out of the underfitting region depends on the number of
training samples. The reparametrization leads to even more stable results.

See, for example, Theorem 3 of
:arxiv:`On the prediction performance of the Lasso <1402.1700>` or
:arxiv:`Simultaneous analysis of Lasso and Dantzig selector <0801.1095>`,
where the regularization parameter is always assumed to be proportional to
1 / sqrt(n_samples).

L2-penalty case
---------------

We can run a similar experiment with the L2 penalty. In this case, theory says
that in order to achieve prediction consistency, the penalty parameter should
be kept constant as the number of samples grows.

.. GENERATED FROM PYTHON SOURCE LINES 124-144

.. code-block:: Python


    model_l2 = LinearSVC(penalty="l2", loss="squared_hinge", dual=True)
    Cs = np.logspace(-8, 4, 11)

    labels = [f"fraction: {train_size}" for train_size in train_sizes]
    results = {"C": Cs}
    for label, train_size in zip(labels, train_sizes):
        cv = ShuffleSplit(train_size=train_size, **shuffle_params)
        train_scores, test_scores = validation_curve(
            model_l2,
            X,
            y,
            param_name="C",
            param_range=Cs,
            cv=cv,
            n_jobs=2,
        )
        results[label] = test_scores.mean(axis=1)
    results = pd.DataFrame(results)


.. GENERATED FROM PYTHON SOURCE LINES 145-171

.. code-block:: Python


    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(12, 6))

    # plot results without scaling C
    results.plot(x="C", ax=axes[0], logx=True)
    axes[0].set_ylabel("CV score")
    axes[0].set_title("No scaling")

    for label in labels:
        best_C = results.loc[results[label].idxmax(), "C"]
        axes[0].axvline(x=best_C, linestyle="--", color="grey", alpha=0.8)

    # plot results by scaling C
    for train_size_idx, label in enumerate(labels):
        results_scaled = results[[label]].assign(
            C_scaled=Cs * float(n_samples * np.sqrt(train_sizes[train_size_idx]))
        )
        results_scaled.plot(x="C_scaled", ax=axes[1], logx=True, label=label)
        best_C_scaled = results_scaled["C_scaled"].loc[results[label].idxmax()]
        axes[1].axvline(x=best_C_scaled, linestyle="--", color="grey", alpha=0.8)
    axes[1].set_title("Scaling C by sqrt(1 / n_samples)")

    fig.suptitle("Effect of scaling C with L2 penalty")
    plt.show()



.. image-sg:: /auto_examples/svm/images/sphx_glr_plot_svm_scale_c_002.png
   :alt: Effect of scaling C with L2 penalty, No scaling, Scaling C by sqrt(1 / n_samples)
   :srcset: /auto_examples/svm/images/sphx_glr_plot_svm_scale_c_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 172-175

For the L2 penalty case, the reparametrization seems to have a smaller impact
on the stability of the optimal value of the regularization. The transition
out of the overfitting region occurs over a wider range, and the accuracy does
not seem to degrade down to chance level.

Try increasing the value to `n_splits=1_000` for better results in the L2
case; this is not shown here due to the limitations of the documentation
builder.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 8.553 seconds)


.. _sphx_glr_download_auto_examples_svm_plot_svm_scale_c.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/svm/plot_svm_scale_c.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_svm_scale_c.ipynb <plot_svm_scale_c.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_svm_scale_c.py <plot_svm_scale_c.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_svm_scale_c.zip <plot_svm_scale_c.zip>`


.. include:: plot_svm_scale_c.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_