.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neural_networks/plot_rbm_logistic_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_neural_networks_plot_rbm_logistic_classification.py>`
        to download the full example code, or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neural_networks_plot_rbm_logistic_classification.py:

==============================================================
Restricted Boltzmann Machine features for digit classification
==============================================================

For greyscale image data where pixel values can be interpreted as degrees of
blackness on a white background, such as handwritten digit recognition, the
Bernoulli Restricted Boltzmann Machine model
(:class:`BernoulliRBM <sklearn.neural_network.BernoulliRBM>`) can perform
effective non-linear feature extraction.

.. GENERATED FROM PYTHON SOURCE LINES 9-13

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause


.. GENERATED FROM PYTHON SOURCE LINES 14-18

Generate data
-------------

To learn good latent representations from a small dataset, we artificially
generate more labeled data by perturbing the training data with linear shifts
of 1 pixel in each direction.

.. GENERATED FROM PYTHON SOURCE LINES 18-54

.. code-block:: Python


    import numpy as np
    from scipy.ndimage import convolve

    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import minmax_scale


    def nudge_dataset(X, Y):
        """
        Produce a dataset 5 times bigger than the original one by shifting
        the 8x8 images in X by 1 pixel to the left, right, down and up.
        """
        direction_vectors = [
            [[0, 1, 0], [0, 0, 0], [0, 0, 0]],
            [[0, 0, 0], [1, 0, 0], [0, 0, 0]],
            [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
            [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
        ]

        def shift(x, w):
            return convolve(x.reshape((8, 8)), mode="constant", weights=w).ravel()

        X = np.concatenate(
            [X] + [np.apply_along_axis(shift, 1, X, vector) for vector in direction_vectors]
        )
        Y = np.concatenate([Y for _ in range(5)], axis=0)
        return X, Y


    X, y = datasets.load_digits(return_X_y=True)
    X = np.asarray(X, "float32")
    X, Y = nudge_dataset(X, y)
    X = minmax_scale(X, feature_range=(0, 1))  # 0-1 scaling

    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)


.. GENERATED FROM PYTHON SOURCE LINES 55-59

Model definition
-----------------

We build a classification pipeline with a BernoulliRBM feature extractor and
a :class:`LogisticRegression <sklearn.linear_model.LogisticRegression>`
classifier.

.. GENERATED FROM PYTHON SOURCE LINES 59-69

.. code-block:: Python


    from sklearn import linear_model
    from sklearn.neural_network import BernoulliRBM
    from sklearn.pipeline import Pipeline

    logistic = linear_model.LogisticRegression(solver="newton-cg", tol=1)
    rbm = BernoulliRBM(random_state=0, verbose=True)

    rbm_features_classifier = Pipeline(steps=[("rbm", rbm), ("logistic", logistic)])


.. GENERATED FROM PYTHON SOURCE LINES 70-74

Training
--------

The hyperparameters of the entire model (learning rate, hidden layer size,
regularization) were optimized by grid search, but the search is not
reproduced here because of runtime constraints.

.. GENERATED FROM PYTHON SOURCE LINES 74-93

.. code-block:: Python


    from sklearn.base import clone

    # Hyperparameters. These were set by cross-validation, using GridSearchCV.
    # Here we do not perform cross-validation, to save time.
    rbm.learning_rate = 0.06
    rbm.n_iter = 10

    # More components tend to give better prediction performance, but also
    # increase the fitting time.
    rbm.n_components = 100
    logistic.C = 6000

    # Train the RBM-Logistic pipeline.
    rbm_features_classifier.fit(X_train, Y_train)

    # Train a logistic regression classifier directly on the pixels.
    raw_pixel_classifier = clone(logistic)
    raw_pixel_classifier.C = 100.0
    raw_pixel_classifier.fit(X_train, Y_train)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.05s
    [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.08s
    [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.88, time = 0.07s
    [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.91, time = 0.07s
    [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.79, time = 0.07s
    [BernoulliRBM] Iteration 6, pseudo-likelihood = -20.96, time = 0.07s
    [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.08s
    [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.50, time = 0.06s
    [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.36, time = 0.07s
    [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.07, time = 0.06s

    LogisticRegression(C=100.0, solver='newton-cg', tol=1)

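The grid search mentioned in this section is not part of the example. As a
rough sketch of how such a search could be set up, the snippet below tunes the
same three kinds of hyperparameters (learning rate, hidden layer size,
regularization) with ``GridSearchCV``. The grid values, the 500-sample
un-augmented subset, ``cv=3``, the default solver with ``max_iter=500``, and
the names ``X_small``, ``y_small``, ``search_pipeline`` and ``param_grid`` are
assumptions made for illustration. They are not the settings that produced the
hyperparameter values used in this example.

.. code-block:: Python

    from sklearn.datasets import load_digits
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV
    from sklearn.neural_network import BernoulliRBM
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import minmax_scale

    # Assumption: a small, un-augmented subset so that the search runs quickly.
    X_small, y_small = load_digits(return_X_y=True)
    X_small = minmax_scale(X_small[:500].astype("float32"), feature_range=(0, 1))
    y_small = y_small[:500]

    # Same two-step structure as the pipeline in this example, rebuilt here so
    # that the sketch is self-contained.
    search_pipeline = Pipeline(
        steps=[
            ("rbm", BernoulliRBM(n_iter=5, random_state=0)),
            ("logistic", LogisticRegression(max_iter=500)),
        ]
    )

    # Hypothetical, deliberately tiny grid. Pipeline parameters are addressed
    # as "<step name>__<parameter name>".
    param_grid = {
        "rbm__learning_rate": [0.01, 0.06],
        "rbm__n_components": [25, 50],
        "logistic__C": [1.0, 100.0],
    }

    search = GridSearchCV(search_pipeline, param_grid, cv=3, n_jobs=1)
    search.fit(X_small, y_small)

    print("Best parameters:", search.best_params_)
    print("Best cross-validated accuracy: %.3f" % search.best_score_)

A full search would use the augmented training set and a wider grid, which is
what makes it too slow to repeat in this example.
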
.. GENERATED FROM PYTHON SOURCE LINES 94-96

Evaluation
----------

.. GENERATED FROM PYTHON SOURCE LINES 96-105

.. code-block:: Python


    from sklearn import metrics

    Y_pred = rbm_features_classifier.predict(X_test)
    print(
        "Logistic regression using RBM features:\n%s\n"
        % (metrics.classification_report(Y_test, Y_pred))
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    Logistic regression using RBM features:
                  precision    recall  f1-score   support

               0       0.10      1.00      0.18       174
               1       0.00      0.00      0.00       184
               2       0.00      0.00      0.00       166
               3       0.00      0.00      0.00       194
               4       0.00      0.00      0.00       186
               5       0.00      0.00      0.00       181
               6       0.00      0.00      0.00       207
               7       0.00      0.00      0.00       154
               8       0.00      0.00      0.00       182
               9       0.00      0.00      0.00       169

        accuracy                           0.10      1797
       macro avg       0.01      0.10      0.02      1797
    weighted avg       0.01      0.10      0.02      1797


.. GENERATED FROM PYTHON SOURCE LINES 106-112

.. code-block:: Python


    Y_pred = raw_pixel_classifier.predict(X_test)
    print(
        "Logistic regression using raw pixel features:\n%s\n"
        % (metrics.classification_report(Y_test, Y_pred))
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    /app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
    Logistic regression using raw pixel features:
                  precision    recall  f1-score   support

               0       0.10      1.00      0.18       174
               1       0.00      0.00      0.00       184
               2       0.00      0.00      0.00       166
               3       0.00      0.00      0.00       194
               4       0.00      0.00      0.00       186
               5       0.00      0.00      0.00       181
               6       0.00      0.00      0.00       207
               7       0.00      0.00      0.00       154
               8       0.00      0.00      0.00       182
               9       0.00      0.00      0.00       169

        accuracy                           0.10      1797
       macro avg       0.01      0.10      0.02      1797
    weighted avg       0.01      0.10      0.02      1797


.. GENERATED FROM PYTHON SOURCE LINES 113-114

The features extracted by the BernoulliRBM are intended to improve
classification accuracy over logistic regression on raw pixels. In the output
captured above, however, both classifiers predict class 0 for every test
sample (accuracy 0.10), so this particular run does not show that improvement.

.. GENERATED FROM PYTHON SOURCE LINES 117-119

Plotting
--------

.. GENERATED FROM PYTHON SOURCE LINES 119-132

.. code-block:: Python


    import matplotlib.pyplot as plt

    plt.figure(figsize=(4.2, 4))
    for i, comp in enumerate(rbm.components_):
        plt.subplot(10, 10, i + 1)
        plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r, interpolation="nearest")
        plt.xticks(())
        plt.yticks(())
    plt.suptitle("100 components extracted by RBM", fontsize=16)
    plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

    plt.show()


.. image-sg:: /auto_examples/neural_networks/images/sphx_glr_plot_rbm_logistic_classification_001.png
   :alt: 100 components extracted by RBM
   :srcset: /auto_examples/neural_networks/images/sphx_glr_plot_rbm_logistic_classification_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.376 seconds)


.. _sphx_glr_download_auto_examples_neural_networks_plot_rbm_logistic_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/neural_networks/plot_rbm_logistic_classification.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_rbm_logistic_classification.ipynb <plot_rbm_logistic_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_rbm_logistic_classification.py <plot_rbm_logistic_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_rbm_logistic_classification.zip <plot_rbm_logistic_classification.zip>`


.. include:: plot_rbm_logistic_classification.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_