用于数字分类的受限玻尔兹曼机特征#

对于像素值可以解释为白色背景上黑度的灰度图像数据,如手写数字识别,伯努利受限玻尔兹曼机模型(BernoulliRBM )可以执行有效的非线性特征提取。

# 作者:scikit-learn 开发者
# SPDX-License-Identifier: BSD-3-Clause

生成数据#

为了从小数据集中学习到良好的潜在表示,我们通过对训练数据进行每个方向1个像素的线性偏移来人为生成更多的标记数据。

import numpy as np
from scipy.ndimage import convolve

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale


def nudge_dataset(X, Y):
    """通过将X中的8x8图像向左、向右、向下、向上移动1像素,这会生成一个比原始数据集大5倍的数据集,
"""
    direction_vectors = [
        [[0, 1, 0], [0, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [1, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
    ]

    def shift(x, w):
        return convolve(x.reshape((8, 8)), mode="constant", weights=w).ravel()

    X = np.concatenate(
        [X] + [np.apply_along_axis(shift, 1, X, vector) for vector in direction_vectors]
    )
    Y = np.concatenate([Y for _ in range(5)], axis=0)
    return X, Y


X, y = datasets.load_digits(return_X_y=True)
X = np.asarray(X, "float32")
X, Y = nudge_dataset(X, y)
X = minmax_scale(X, feature_range=(0, 1))  # 0-1 scaling

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

模型定义#

我们构建了一个分类管道,使用BernoulliRBM特征提取器和一个:class:LogisticRegression <sklearn.linear_model.LogisticRegression> 分类器。

from sklearn import linear_model
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

logistic = linear_model.LogisticRegression(solver="newton-cg", tol=1)
rbm = BernoulliRBM(random_state=0, verbose=True)

rbm_features_classifier = Pipeline(steps=[("rbm", rbm), ("logistic", logistic)])

Training#

整个模型的超参数(学习率、隐藏层大小、正则化)通过网格搜索进行了优化,但由于运行时间的限制,这里不再重复搜索过程。

from sklearn.base import clone

# 超参数。这些参数是通过交叉验证和GridSearchCV设置的。这里我们不进行交叉验证以节省时间。
rbm.learning_rate = 0.06
rbm.n_iter = 10

# 更多的组件往往会带来更好的预测性能,但也会增加拟合时间。
rbm.n_components = 100
logistic.C = 6000

# 训练 RBM-Logistic 管道
rbm_features_classifier.fit(X_train, Y_train)

# 直接在像素上训练逻辑回归分类器
raw_pixel_classifier = clone(logistic)
raw_pixel_classifier.C = 100.0
raw_pixel_classifier.fit(X_train, Y_train)
[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.05s
[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.08s
[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.88, time = 0.07s
[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.91, time = 0.07s
[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.79, time = 0.07s
[BernoulliRBM] Iteration 6, pseudo-likelihood = -20.96, time = 0.07s
[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.08s
[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.50, time = 0.06s
[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.36, time = 0.07s
[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.07, time = 0.06s
LogisticRegression(C=100.0, solver='newton-cg', tol=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Evaluation#

from sklearn import metrics

Y_pred = rbm_features_classifier.predict(X_test)
print(
    "Logistic regression using RBM features:\n%s\n"
    % (metrics.classification_report(Y_test, Y_pred))
)
/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

Logistic regression using RBM features:
              precision    recall  f1-score   support

           0       0.10      1.00      0.18       174
           1       0.00      0.00      0.00       184
           2       0.00      0.00      0.00       166
           3       0.00      0.00      0.00       194
           4       0.00      0.00      0.00       186
           5       0.00      0.00      0.00       181
           6       0.00      0.00      0.00       207
           7       0.00      0.00      0.00       154
           8       0.00      0.00      0.00       182
           9       0.00      0.00      0.00       169

    accuracy                           0.10      1797
   macro avg       0.01      0.10      0.02      1797
weighted avg       0.01      0.10      0.02      1797
Y_pred = raw_pixel_classifier.predict(X_test)
print(
    "Logistic regression using raw pixel features:\n%s\n"
    % (metrics.classification_report(Y_test, Y_pred))
)
/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/app/scikit-learn-main-origin/sklearn/metrics/_classification.py:1404: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

Logistic regression using raw pixel features:
              precision    recall  f1-score   support

           0       0.10      1.00      0.18       174
           1       0.00      0.00      0.00       184
           2       0.00      0.00      0.00       166
           3       0.00      0.00      0.00       194
           4       0.00      0.00      0.00       186
           5       0.00      0.00      0.00       181
           6       0.00      0.00      0.00       207
           7       0.00      0.00      0.00       154
           8       0.00      0.00      0.00       182
           9       0.00      0.00      0.00       169

    accuracy                           0.10      1797
   macro avg       0.01      0.10      0.02      1797
weighted avg       0.01      0.10      0.02      1797

伯努利受限玻尔兹曼机提取的特征有助于提高相对于原始像素的逻辑回归分类准确性。

Plotting#

import matplotlib.pyplot as plt

plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
    plt.subplot(10, 10, i + 1)
    plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r, interpolation="nearest")
    plt.xticks(())
    plt.yticks(())
plt.suptitle("100 components extracted by RBM", fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()
100 components extracted by RBM

Total running time of the script: (0 minutes 1.376 seconds)

Related examples

绘制分类概率

绘制分类概率

识别手写数字

识别手写数字

逻辑函数

逻辑函数

具有异构数据源的列转换器

具有异构数据源的列转换器

Gallery generated by Sphinx-Gallery