========================================
在鸢尾花数据集上绘制多类SGD
========================================

在鸢尾花数据集上绘制多类 SGD 分类器的决策面。
三个一对多（OVA）分类器各自对应的超平面用虚线表示。

.. GENERATED FROM PYTHON SOURCE LINES 10-83

.. image-sg:: /auto_examples/linear_model/images/sphx_glr_plot_sgd_iris_001.png
   :alt: Decision surface of multi-class SGD
   :srcset: /auto_examples/linear_model/images/sphx_glr_plot_sgd_iris_001.png
   :class: sphx-glr-single-img

.. code-block:: Python

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn import datasets
    from sklearn.inspection import DecisionBoundaryDisplay
    from sklearn.linear_model import SGDClassifier

    # Load some data to play with.
    iris = datasets.load_iris()

    # Keep only the first two features so the decision surface is 2-D.
    # (A dataset that already had two features would avoid this slice.)
    X, y = iris.data[:, :2], iris.target
    colors = "bry"

    # Shuffle the samples reproducibly with a fixed legacy global seed.
    # NOTE(review): SGDClassifier below gets no ``random_state``, so it most
    # likely draws from this same seeded global NumPy RNG -- keep the seeding
    # here, before the fit, to keep the example reproducible.
    idx = np.arange(len(X))
    np.random.seed(13)
    np.random.shuffle(idx)
    X, y = X[idx], y[idx]

    # Standardize each feature column to zero mean and unit variance.
    mean, std = X.mean(axis=0), X.std(axis=0)
    X = (X - mean) / std

    # Fit a linear SGD classifier on the two standardized features.
    clf = SGDClassifier(alpha=0.001, max_iter=100)
    clf.fit(X, y)
    # Draw the predicted class regions over a grid spanning the data,
    # on the current axes.
    ax = plt.gca()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=plt.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )

    # Also plot the training points, one colour per class.
    for i, color in zip(clf.classes_, colors):
        # A boolean mask selects this class's rows as 1-D arrays. It also
        # leaves the shuffle permutation ``idx`` untouched.
        mask = y == i
        ax.scatter(
            X[mask, 0],
            X[mask, 1],
            c=color,
            label=iris.target_names[i],
            edgecolor="black",
            s=20,
        )
    ax.set_title("Decision surface of multi-class SGD")
    # Fit the view tightly around everything drawn, once all artists exist.
    ax.axis("tight")

    # Plot the three one-against-all classifiers.
    # Record the current view limits; the hyperplane segments are drawn
    # across this range.
    (xmin, xmax), (ymin, ymax) = plt.xlim(), plt.ylim()
    # Per-class weights and intercepts, indexed as coef[c, :] / intercept[c].
    coef, intercept = clf.coef_, clf.intercept_

    def plot_hyperplane(c, color):
        """Draw the one-vs-all decision boundary of class ``c`` as a dashed line.

        The boundary is the set of points satisfying
        ``coef[c, 0] * x0 + coef[c, 1] * x1 + intercept[c] == 0``. The function
        reads ``coef``, ``intercept`` and the view limits ``xmin``/``xmax``
        (and ``ymin``/``ymax`` for a vertical boundary) from the enclosing
        script.

        Parameters
        ----------
        c : int
            Row index into ``coef`` / ``intercept`` (the class label here).
        color : str
            Matplotlib colour used for the dashed line.
        """
        w0, w1 = coef[c, 0], coef[c, 1]
        b = intercept[c]

        if w1 == 0:
            # Vertical boundary (x0 = -b / w0). Solving for x1 would divide by
            # zero and yield inf/nan, so span the y-range instead.
            if w0 != 0:
                x0 = -b / w0
                plt.plot([x0, x0], [ymin, ymax], ls="--", color=color)
            # If w0 == w1 == 0 there is no finite boundary to draw.
            return

        def line(x0):
            # Solve w0 * x0 + w1 * x1 + b = 0 for x1.
            return (-(x0 * w0) - b) / w1

        plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color)

    # Overlay each class's one-vs-all hyperplane in that class's colour,
    # then add the legend and display the figure.
    for label, shade in zip(clf.classes_, colors):
        plot_hyperplane(label, shade)
    plt.legend()
    plt.show()

**Total running time of the script:** (0 minutes 0.055 seconds)