.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/semi_supervised/plot_semi_supervised_newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:


================================================
Semi-supervised Classification on a Text Dataset
================================================

In this example, semi-supervised classifiers are trained on the 20 newsgroups
dataset (which will be downloaded automatically).

You can adjust the number of categories by passing their names to the dataset
loader, or set the category list to `None` to get all 20 of them.

.. GENERATED FROM PYTHON SOURCE LINES 11-107

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    2823 documents
    5 categories

    Supervised SGDClassifier on 100% of the data:
    Number of training samples: 2117
    Unlabeled samples in training set: 0
    Micro-averaged F1 score on test set: 0.888
    ----------

    Supervised SGDClassifier on 20% of the training data:
    Number of training samples: 412
    Unlabeled samples in training set: 0
    Micro-averaged F1 score on test set: 0.754
    ----------

    SelfTrainingClassifier on 20% of the training data (rest is unlabeled):
    Number of training samples: 2117
    Unlabeled samples in training set: 1705
    End of iteration 1, added 1084 new labels.
    End of iteration 2, added 187 new labels.
    End of iteration 3, added 55 new labels.
    End of iteration 4, added 28 new labels.
    End of iteration 5, added 23 new labels.
    End of iteration 6, added 14 new labels.
    End of iteration 7, added 2 new labels.
    End of iteration 8, added 9 new labels.
    End of iteration 9, added 8 new labels.
    End of iteration 10, added 4 new labels.
    Micro-averaged F1 score on test set: 0.806
    ----------

    LabelSpreading on 20% of the data (rest is unlabeled):
    Number of training samples: 2117
    Unlabeled samples in training set: 1705
    Micro-averaged F1 score on test set: 0.644
    ----------




|

.. code-block:: Python


    import numpy as np

    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.linear_model import SGDClassifier
    from sklearn.metrics import f1_score
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer
    from sklearn.semi_supervised import LabelSpreading, SelfTrainingClassifier

    # Load the dataset restricted to the first five categories
    data = fetch_20newsgroups(
        subset="train",
        categories=[
            "alt.atheism",
            "comp.graphics",
            "comp.os.ms-windows.misc",
            "comp.sys.ibm.pc.hardware",
            "comp.sys.mac.hardware",
        ],
    )
    print("%d documents" % len(data.filenames))
    print("%d categories" % len(data.target_names))
    print()

    # Parameters
    sgd_params = dict(alpha=1e-5, penalty="l2", loss="log_loss")
    vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)

    # Supervised pipeline
    pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SGDClassifier(**sgd_params)),
        ]
    )
    # Self-training pipeline
    st_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SelfTrainingClassifier(SGDClassifier(**sgd_params), verbose=True)),
        ]
    )
    # LabelSpreading pipeline
    ls_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            # Convert the sparse TF-IDF matrix to a dense array before LabelSpreading
            ("toarray", FunctionTransformer(lambda x: x.toarray())),
            ("clf", LabelSpreading()),
        ]
    )


    def eval_and_print_metrics(clf, X_train, y_train, X_test, y_test):
        print("Number of training samples:", len(X_train))
        print("Unlabeled samples in training set:", sum(1 for x in y_train if x == -1))
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        print(
            "Micro-averaged F1 score on test set: %0.3f"
            % f1_score(y_test, y_pred, average="micro")
        )
        print("-" * 10)
        print()


    if __name__ == "__main__":
        X, y = data.data, data.target
        X_train, X_test, y_train, y_test = train_test_split(X, y)

        print("Supervised SGDClassifier on 100% of the data:")
        eval_and_print_metrics(pipeline, X_train, y_train, X_test, y_test)

        # Draw a random mask that selects roughly 20% of the training set
        y_mask = np.random.rand(len(y_train)) < 0.2

        # X_20 and y_20 are the subset of the training set selected by the mask
        X_20, y_20 = map(
            list, zip(*((x, y) for x, y, m in zip(X_train, y_train, y_mask) if m))
        )
        print("Supervised SGDClassifier on 20% of the training data:")
        eval_and_print_metrics(pipeline, X_20, y_20, X_test, y_test)

        # Mark every sample outside the mask as unlabeled (-1)
        y_train[~y_mask] = -1
        print("SelfTrainingClassifier on 20% of the training data (rest is unlabeled):")
        eval_and_print_metrics(st_pipeline, X_train, y_train, X_test, y_test)

        print("LabelSpreading on 20% of the data (rest is unlabeled):")
        eval_and_print_metrics(ls_pipeline, X_train, y_train, X_test, y_test)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.662 seconds)


.. _sphx_glr_download_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/semi_supervised/plot_semi_supervised_newsgroups.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_semi_supervised_newsgroups.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_semi_supervised_newsgroups.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_semi_supervised_newsgroups.zip `


.. include:: plot_semi_supervised_newsgroups.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
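
The masking step in the script (``y_mask`` followed by setting the remaining labels to ``-1``) can be looked at in isolation. The sketch below is only an illustration under stated assumptions: it uses the standard-library ``random`` module instead of NumPy, toy integer labels instead of the newsgroups targets, and a hypothetical helper named ``mask_labels`` that is not part of scikit-learn. Because each sample is kept independently with probability 0.2, the labeled count is only approximately 20% of the data, which would explain why the output above reports 412 labeled samples rather than exactly one fifth of 2117.

```python
import random


def mask_labels(labels, labeled_fraction=0.2, seed=0):
    """Return (masked_labels, keep_mask); unlabeled entries are set to -1.

    Hypothetical helper for illustration only; not part of scikit-learn.
    """
    rng = random.Random(seed)
    # Keep each label independently with probability `labeled_fraction`,
    # analogous to `np.random.rand(len(y_train)) < 0.2` in the script.
    keep_mask = [rng.random() < labeled_fraction for _ in labels]
    # -1 marks a sample as unlabeled, the same marker the script uses.
    masked = [y if keep else -1 for y, keep in zip(labels, keep_mask)]
    return masked, keep_mask


labels = [i % 5 for i in range(1000)]  # toy labels for 5 classes (assumption)
masked, keep_mask = mask_labels(labels)
n_labeled = sum(keep_mask)
n_unlabeled = sum(1 for y in masked if y == -1)
print("labeled:", n_labeled, "unlabeled:", n_unlabeled)
```

The full training set is still passed to the semi-supervised estimators; only the label vector changes, so the unlabeled documents continue to contribute their features.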