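.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/text/plot_document_classification_20newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_text_plot_document_classification_20newsgroups.py:

======================================================
Classification of text documents using sparse features
======================================================

This example shows how scikit-learn can be used to classify documents by topic using a `Bag of Words approach `_. It uses a Tf-idf-weighted document-term sparse matrix to encode the features and demonstrates several classifiers that can efficiently handle sparse matrices.

For document analysis via an unsupervised learning approach, see the example script :ref:`sphx_glr_auto_examples_text_plot_document_clustering.py`.

.. GENERATED FROM PYTHON SOURCE LINES 14-19

.. code-block:: Python

    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

.. GENERATED FROM PYTHON SOURCE LINES 20-24

Loading and vectorizing the 20 newsgroups text dataset
======================================================

We define a function to load data from :ref:`20newsgroups_dataset`, which comprises around 18,000 newsgroups posts on 20 topics split in two subsets: one for training (or development) and the other one for testing (or performance evaluation). Note that, by default, the text samples contain some message metadata such as `'headers'`, `'footers'` (signatures) and `'quotes'` of other posts. The `fetch_20newsgroups` function therefore accepts a parameter named `remove` to attempt stripping such information, which can make the classification problem "too easy". This is achieved using simple heuristics that are neither perfect nor standard, hence it is disabled by default.

.. GENERATED FROM PYTHON SOURCE LINES 24-107

..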
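
To give a feel for what "simple heuristics" can mean here, the following is a minimal pure-Python sketch. The helper names ``strip_header`` and ``strip_quotes`` and the sample post are hypothetical and written only for illustration; they are not scikit-learn's implementation of the `remove` option, which may use different rules.

.. code-block:: Python

    import re

    # Hypothetical helpers for illustration only (not scikit-learn's code).
    # Assumption: a quoted line starts with ">" or "|", possibly after spaces.
    QUOTE_LINE = re.compile(r"^\s*(>|\|)")


    def strip_header(text):
        """Drop everything up to and including the first blank line."""
        _header, separator, body = text.partition("\n\n")
        # Without a blank line we assume the text has no header block.
        return body if separator else text


    def strip_quotes(text):
        """Drop lines that start with a common quoting marker."""
        kept = [line for line in text.split("\n") if not QUOTE_LINE.match(line)]
        return "\n".join(kept)


    post = (
        "From: someone@example.org\n"
        "Subject: Re: orbits\n"
        "\n"
        "> What altitude is that?\n"
        "About 400 km.\n"
    )
    cleaned = strip_quotes(strip_header(post))
    print(cleaned)

Rules of this kind are easy to fool (a post body may legitimately start a line with ``>``), which is consistent with the caveat above that the filtering is imperfect.

..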
.. code-block:: Python

    from time import time

    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import TfidfVectorizer

    categories = [
        "alt.atheism",
        "talk.religion.misc",
        "comp.graphics",
        "sci.space",
    ]


    def size_mb(docs):
        return sum(len(s.encode("utf-8")) for s in docs) / 1e6


    def load_dataset(verbose=False, remove=()):
        """Load and vectorize the 20 newsgroups dataset."""

        data_train = fetch_20newsgroups(
            subset="train",
            categories=categories,
            shuffle=True,
            random_state=42,
            remove=remove,
        )

        data_test = fetch_20newsgroups(
            subset="test",
            categories=categories,
            shuffle=True,
            random_state=42,
            remove=remove,
        )

        # order of labels in `target_names` can be different from `categories`
        target_names = data_train.target_names

        # split target in a training set and a test set
        y_train, y_test = data_train.target, data_test.target

        # Extracting features from the training data using a sparse vectorizer
        t0 = time()
        vectorizer = TfidfVectorizer(
            sublinear_tf=True, max_df=0.5, min_df=5, stop_words="english"
        )
        X_train = vectorizer.fit_transform(data_train.data)
        duration_train = time() - t0

        # Extracting features from the test data using the same vectorizer
        t0 = time()
        X_test = vectorizer.transform(data_test.data)
        duration_test = time() - t0

        feature_names = vectorizer.get_feature_names_out()

        if verbose:
            # compute size of loaded data
            data_train_size_mb = size_mb(data_train.data)
            data_test_size_mb = size_mb(data_test.data)

            print(
                f"{len(data_train.data)} documents - "
                f"{data_train_size_mb:.2f}MB (training set)"
            )
            print(f"{len(data_test.data)} documents - {data_test_size_mb:.2f}MB (test set)")
            print(f"{len(target_names)} categories")
            print(
                f"vectorize training done in {duration_train:.3f}s "
                f"at {data_train_size_mb / duration_train:.3f}MB/s"
            )
            print(f"n_samples: {X_train.shape[0]}, n_features: {X_train.shape[1]}")
            print(
                f"vectorize testing done in {duration_test:.3f}s "
                f"at {data_test_size_mb / duration_test:.3f}MB/s"
            )
            print(f"n_samples: {X_test.shape[0]}, n_features: {X_test.shape[1]}")

        return X_train, X_test, y_train, y_test, feature_names, target_names

..
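
The ``TfidfVectorizer`` call above combines sublinear term-frequency scaling with inverse document frequency weighting. As a rough illustration, here is a pure-Python sketch of that weighting on a toy corpus. It assumes the smoothed idf formula ``ln((1 + n) / (1 + df)) + 1`` and L2 row normalization described in the scikit-learn documentation, uses naive whitespace tokenization, and ignores `max_df`, `min_df` and stop words. The function name ``tfidf`` and the toy documents are hypothetical.

.. code-block:: Python

    import math
    from collections import Counter


    def tfidf(docs):
        """Return one sparse {term: weight} dict per document (illustrative sketch)."""
        tokenized = [doc.lower().split() for doc in docs]
        n_docs = len(tokenized)
        # Document frequency: number of documents that contain each term.
        df = Counter(term for tokens in tokenized for term in set(tokens))

        rows = []
        for tokens in tokenized:
            counts = Counter(tokens)
            weights = {
                # Sublinear tf (1 + ln tf) multiplied by the smoothed idf.
                term: (1 + math.log(tf))
                * (math.log((1 + n_docs) / (1 + df[term])) + 1)
                for term, tf in counts.items()
            }
            # L2-normalize the row; fall back to 1.0 for an empty document.
            norm = math.sqrt(sum(w * w for w in weights.values())) or 1.0
            rows.append({term: w / norm for term, w in weights.items()})
        return rows


    docs = ["nasa orbit orbit launch", "graphics image file", "orbit image"]
    rows = tfidf(docs)
    print(rows[0])

Each row only stores the terms that occur in the document, which is the same idea as the sparse matrix returned by the vectorizer: most of the vocabulary has weight zero for any given document.

..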
.. GENERATED FROM PYTHON SOURCE LINES 108-118

Analysis of a bag-of-words document classifier
==============================================

We will now train a classifier twice, once on the text samples including metadata and once after stripping the metadata. For both cases we will analyze the classification errors on a test set using a confusion matrix and inspect the coefficients that define the classification function of the trained models.

Model without metadata stripping
--------------------------------

We start by using the custom function `load_dataset` to load the data without metadata stripping.

.. GENERATED FROM PYTHON SOURCE LINES 118-123

.. code-block:: Python

    X_train, X_test, y_train, y_test, feature_names, target_names = load_dataset(
        verbose=True
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    2034 documents - 3.98MB (training set)
    1353 documents - 2.87MB (test set)
    4 categories
    vectorize training done in 0.203s at 19.580MB/s
    n_samples: 2034, n_features: 7831
    vectorize testing done in 0.143s at 20.087MB/s
    n_samples: 1353, n_features: 7831

.. GENERATED FROM PYTHON SOURCE LINES 124-125

Our first model is an instance of the :class:`~sklearn.linear_model.RidgeClassifier` class. This is a linear classification model that uses the mean squared error on {-1, 1} encoded targets, one for each possible class. Contrary to :class:`~sklearn.linear_model.LogisticRegression`, :class:`~sklearn.linear_model.RidgeClassifier` does not provide probabilistic predictions (no `predict_proba` method), but it is often faster to train.

.. GENERATED FROM PYTHON SOURCE LINES 125-133

.. code-block:: Python

    from sklearn.linear_model import RidgeClassifier

    clf = RidgeClassifier(tol=1e-2, solver="sparse_cg")
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)

.. GENERATED FROM PYTHON SOURCE LINES 134-135

We plot the confusion matrix of this classifier to find out whether there is a pattern in the classification errors.

.. GENERATED FROM PYTHON SOURCE LINES 135-149

.. code-block:: Python

    import matplotlib.pyplot as plt

    from sklearn.metrics import ConfusionMatrixDisplay

    fig, ax = plt.subplots(figsize=(10, 5))
    ConfusionMatrixDisplay.from_predictions(y_test, pred, ax=ax)
    ax.xaxis.set_ticklabels(target_names)
    ax.yaxis.set_ticklabels(target_names)
    _ = ax.set_title(
        f"Confusion Matrix for {clf.__class__.__name__}\non the original documents"
    )

..
.. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_001.png
   :alt: Confusion Matrix for RidgeClassifier on the original documents
   :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 150-155

The confusion matrix highlights that documents of the `alt.atheism` class are often confused with documents of the `talk.religion.misc` class and vice versa, which is expected since the topics are semantically related.

We also observe that some documents of the `sci.space` class can be misclassified as `comp.graphics`, while the converse is much rarer. A manual inspection of those misclassified documents would be required to get some insight into this asymmetry. It could be that the vocabulary of the space topic is more specific than the vocabulary of computer graphics.

We can gain a deeper understanding of how this classifier makes its decisions by looking at the words with the highest average feature effects:

.. GENERATED FROM PYTHON SOURCE LINES 155-206

.. code-block:: Python

    import numpy as np
    import pandas as pd


    def plot_feature_effects():
        # learned coefficients weighted by frequency of appearance
        average_feature_effects = clf.coef_ * np.asarray(X_train.mean(axis=0)).ravel()

        for i, label in enumerate(target_names):
            top5 = np.argsort(average_feature_effects[i])[-5:][::-1]
            if i == 0:
                top = pd.DataFrame(feature_names[top5], columns=[label])
                top_indices = top5
            else:
                top[label] = feature_names[top5]
                top_indices = np.concatenate((top_indices, top5), axis=None)
        top_indices = np.unique(top_indices)
        predictive_words = feature_names[top_indices]

        # plot feature effects
        bar_size = 0.25
        padding = 0.75
        y_locs = np.arange(len(top_indices)) * (4 * bar_size + padding)

        fig, ax = plt.subplots(figsize=(10, 8))
        for i, label in enumerate(target_names):
            ax.barh(
                y_locs + (i - 2) * bar_size,
                average_feature_effects[i, top_indices],
                height=bar_size,
                label=label,
            )
        ax.set(
            yticks=y_locs,
            yticklabels=predictive_words,
            ylim=[
                0 - 4 * bar_size,
                len(top_indices) * (4 * bar_size + padding) - 4 * bar_size,
            ],
        )
        ax.legend(loc="lower right")

        print("top 5 keywords per class:")
        print(top)

        return ax


    _ = plot_feature_effects().set_title("Average feature effect on the original data")

..
.. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_002.png
   :alt: Average feature effect on the original data
   :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    top 5 keywords per class:
      alt.atheism comp.graphics sci.space talk.religion.misc
    0       keith      graphics     space          christian
    1         god    university      nasa                com
    2    atheists        thanks     orbit                god
    3      people          does      moon           morality
    4     caltech         image    access             people

.. GENERATED FROM PYTHON SOURCE LINES 207-208

We can observe that the most predictive words are often strongly positively associated with a single class and negatively associated with all the other classes. Most of those positive associations are quite interpretable. However, some words such as "god" and "people" are positively associated with both "talk.religion.misc" and "alt.atheism", as those two classes expectedly share some common vocabulary. Notice, however, that there are also words such as "christian" and "morality" that are only positively associated with "talk.religion.misc". Furthermore, in this version of the dataset, the word "caltech" is one of the top predictive features for atheism due to pollution in the dataset coming from some sort of metadata, such as the email addresses of the senders of previous emails in the discussion, as can be seen below:

.. GENERATED FROM PYTHON SOURCE LINES 208-219

.. code-block:: Python

    data_train = fetch_20newsgroups(
        subset="train", categories=categories, shuffle=True, random_state=42
    )

    for doc in data_train.data:
        if "caltech" in doc:
            print(doc)
            break

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    From: livesey@solntze.wpd.sgi.com (Jon Livesey)
    Subject: Re: Morality? (was Re: , keith@cco.caltech.edu (Keith Allan Schneider) writes:
    |> livesey@solntze.wpd.sgi.com (Jon Livesey) writes:
    |>
    |> >>>Explain to me
    |> >>>how instinctive acts can be moral acts, and I am happy to listen.
    |> >>For example, if it were instinctive not to murder...
    |> >
    |> >Then not murdering would have no moral significance, since there
    |> >would be nothing voluntary about it.
    |>
    |> See, there you go again, saying that a moral act is only significant
    |> if it is "voluntary." Why do you think this?

    If you force me to do something, am I morally responsible for it?

    |>
    |> And anyway, humans have the ability to disregard some of their instincts.

    Well, make up your mind. Is it to be "instinctive not to murder" or not?

    |>
    |> >>So, only intelligent beings can be moral, even if the bahavior of other
    |> >>beings mimics theirs?
    |> >
    |> >You are starting to get the point. Mimicry is not necessarily the
    |> >same as the action being imitated. A Parrot saying "Pretty Polly"
    |> >isn't necessarily commenting on the pulchritude of Polly.
    |>
    |> You are attaching too many things to the term "moral," I think.
    |> Let's try this: is it "good" that animals of the same species
    |> don't kill each other. Or, do you think this is right?

    It's not even correct. Animals of the same species do kill one another.

    |>
    |> Or do you think that animals are machines, and that nothing they do
    |> is either right nor wrong?

    Sigh. I wonder how many times we have been round this loop.

    I think that instinctive bahaviour has no moral significance. I am quite prepared to believe that higher animals, such as primates, have the beginnings of a moral sense, since they seem to exhibit self-awareness.

    |>
    |>
    |> >>Animals of the same species could kill each other arbitarily, but
    |> >>they don't.
    |> >
    |> >They do. I and other posters have given you many examples of exactly
    |> >this, but you seem to have a very short memory.
    |>
    |> Those weren't arbitrary killings. They were slayings related to some
    |> sort of mating ritual or whatnot.

    So what? Are you trying to say that some killing in animals has a moral significance and some does not? Is this your natural morality>

    |>
    |> >>Are you trying to say that this isn't an act of morality because
    |> >>most animals aren't intelligent enough to think like we do?
    |> >
    |> >I'm saying:
    |> > "There must be the possibility that the organism - it's not
    |> > just people we are talking about - can consider alternatives."
    |> >
    |> >It's right there in the posting you are replying to.
    |>
    |> Yes it was, but I still don't understand your distinctions. What
    |> do you mean by "consider?" Can a small child be moral? How about
    |> a gorilla? A dolphin? A platypus? Where is the line drawn?
    Does
    |> the being need to be self aware?

    Are you blind? What do you think that this sentence means? "There must be the possibility that the organism - it's not just people we are talking about - can consider alternatives." What would that imply?

    |>
    |> What *do* you call the mechanism which seems to prevent animals of
    |> the same species from (arbitrarily) killing each other? Don't
    |> you find the fact that they don't at all significant?

    I find the fact that they do to be significant.

    jon.

.. GENERATED FROM PYTHON SOURCE LINES 220-227

Such headers, signature footers (and quoted metadata from previous messages) can be considered side information that artificially reveals the newsgroup by identifying its registered members. We would rather have our text classifier learn only from the "main content" of each text document instead of relying on the leaked identity of the writers.

Model with metadata stripping
-----------------------------

The `remove` option of the 20 newsgroups dataset loader in scikit-learn heuristically attempts to filter out some of this unwanted metadata, which makes the classification problem artificially easier. Be aware that such filtering of the text contents is far from perfect.

Let us try to leverage this option to train a text classifier that does not rely too much on this kind of metadata to make its decisions:

.. GENERATED FROM PYTHON SOURCE LINES 227-248

.. code-block:: Python

    (
        X_train,
        X_test,
        y_train,
        y_test,
        feature_names,
        target_names,
    ) = load_dataset(remove=("headers", "footers", "quotes"))

    clf = RidgeClassifier(tol=1e-2, solver="sparse_cg")
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)

    fig, ax = plt.subplots(figsize=(10, 5))
    ConfusionMatrixDisplay.from_predictions(y_test, pred, ax=ax)
    ax.xaxis.set_ticklabels(target_names)
    ax.yaxis.set_ticklabels(target_names)
    _ = ax.set_title(
        f"Confusion Matrix for {clf.__class__.__name__}\non filtered documents"
    )

.. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_003.png
   :alt: Confusion Matrix for RidgeClassifier on filtered documents
   :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 249-250

Looking at the confusion matrix, it becomes more evident that the scores of the model trained with metadata were over-optimistic. The classification problem without access to the metadata yields lower accuracy but is more representative of the intended text classification problem.

.. GENERATED FROM PYTHON SOURCE LINES 250-254

.. code-block:: Python

    _ = plot_feature_effects().set_title("Average feature effects on filtered documents")

..
.. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_004.png
   :alt: Average feature effects on filtered documents
   :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    top 5 keywords per class:
      alt.atheism comp.graphics sci.space talk.religion.misc
    0         don      graphics     space                god
    1      people          file      like          christian
    2         say        thanks      nasa              jesus
    3    religion         image     orbit         christians
    4        post          does    launch              wrong

.. GENERATED FROM PYTHON SOURCE LINES 255-256

In the next section we keep the dataset without metadata to compare several classifiers.

.. GENERATED FROM PYTHON SOURCE LINES 259-263

Benchmarking classifiers
========================

Scikit-learn provides many different kinds of classification algorithms. In this section we will train a selection of those classifiers on the same text classification problem and measure both their generalization performance (accuracy on the test set) and their computational performance (speed), both at training time and testing time. For this purpose we define the following benchmarking utilities:

.. GENERATED FROM PYTHON SOURCE LINES 263-298

.. code-block:: Python

    from sklearn import metrics
    from sklearn.utils.extmath import density


    def benchmark(clf, custom_name=False):
        print("_" * 80)
        print("Training: ")
        print(clf)
        t0 = time()
        clf.fit(X_train, y_train)
        train_time = time() - t0
        print(f"train time: {train_time:.3}s")

        t0 = time()
        pred = clf.predict(X_test)
        test_time = time() - t0
        print(f"test time: {test_time:.3}s")

        score = metrics.accuracy_score(y_test, pred)
        print(f"accuracy: {score:.3}")

        if hasattr(clf, "coef_"):
            print(f"dimensionality: {clf.coef_.shape[1]}")
            print(f"density: {density(clf.coef_)}")
            print()

        print()
        if custom_name:
            clf_descr = str(custom_name)
        else:
            clf_descr = clf.__class__.__name__
        return clf_descr, score, train_time, test_time

.. GENERATED FROM PYTHON SOURCE LINES 299-304

We now train and test 8 different classification models on the dataset and collect performance results for each model. The goal of this study is to highlight the computation/accuracy trade-offs of different types of classifiers for such a multi-class text classification problem.

Note that the most important hyperparameter values were tuned using a grid search procedure that is not shown in this notebook for the sake of simplicity. See the example script :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_text_feature_extraction.py` for a demo of how such tuning can be done.

.. GENERATED FROM PYTHON SOURCE LINES 304-335

..
.. code-block:: Python

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression, SGDClassifier
    from sklearn.naive_bayes import ComplementNB
    from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
    from sklearn.svm import LinearSVC

    results = []
    for clf, name in (
        (LogisticRegression(C=5, max_iter=1000), "Logistic Regression"),
        (RidgeClassifier(alpha=1.0, solver="sparse_cg"), "Ridge Classifier"),
        (KNeighborsClassifier(n_neighbors=100), "kNN"),
        (RandomForestClassifier(), "Random Forest"),
        # L2 penalty Linear SVC
        (LinearSVC(C=0.1, dual=False, max_iter=1000), "Linear SVC"),
        # L2 penalty Linear SGD
        (
            SGDClassifier(
                loss="log_loss", alpha=1e-4, n_iter_no_change=3, early_stopping=True
            ),
            "log-loss SGD",
        ),
        # NearestCentroid (aka Rocchio classifier)
        (NearestCentroid(), "NearestCentroid"),
        # Sparse naive Bayes classifier
        (ComplementNB(alpha=0.1), "Complement naive Bayes"),
    ):
        print("=" * 80)
        print(name)
        results.append(benchmark(clf, name))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ================================================================================
    Logistic Regression
    ________________________________________________________________________________
    Training:
    LogisticRegression(C=5, max_iter=1000)
    train time: 3.84s
    test time: 0.00113s
    accuracy: 0.772
    dimensionality: 5316
    density: 1.0


    ================================================================================
    Ridge Classifier
    ________________________________________________________________________________
    Training:
    RidgeClassifier(solver='sparse_cg')
    train time: 0.0203s
    test time: 0.000765s
    accuracy: 0.76
    dimensionality: 5316
    density: 1.0


    ================================================================================
    kNN
    ________________________________________________________________________________
    Training:
    KNeighborsClassifier(n_neighbors=100)
    train time: 0.000469s
    test time: 0.0809s
    accuracy: 0.752

    ================================================================================
    Random Forest
    ________________________________________________________________________________
    Training:
    RandomForestClassifier()
    train time: 2.97s
    test time: 0.0368s
    accuracy: 0.702

    ================================================================================
    Linear SVC
    ________________________________________________________________________________
    Training:
    LinearSVC(C=0.1, dual=False)
    train time: 0.0197s
    test time: 0.000665s
    accuracy: 0.752
    dimensionality: 5316
    density: 1.0


    ================================================================================
    log-loss SGD
    ________________________________________________________________________________
    Training:
    SGDClassifier(early_stopping=True, loss='log_loss', n_iter_no_change=3)
    train time: 0.0241s
    test time: 0.000684s
    accuracy: 0.766
    dimensionality: 5316
    density: 1.0


    ================================================================================
    NearestCentroid
    ________________________________________________________________________________
    Training:
    NearestCentroid()
    train time: 0.00173s
    test time: 0.00173s
    accuracy: 0.748

    ================================================================================
    Complement naive Bayes
    ________________________________________________________________________________
    Training:
    ComplementNB(alpha=0.1)
    train time: 0.00109s
    test time: 0.000638s
    accuracy: 0.779

.. GENERATED FROM PYTHON SOURCE LINES 336-340

Plotting accuracy, training and test time of each classifier
============================================================

The scatter plots show the trade-off between the test accuracy and the training and testing time of each classifier.

.. GENERATED FROM PYTHON SOURCE LINES 340-370

..
.. code-block:: Python

    indices = np.arange(len(results))

    results = [[x[i] for x in results] for i in range(4)]

    clf_names, score, training_time, test_time = results
    training_time = np.array(training_time)
    test_time = np.array(test_time)

    fig, ax1 = plt.subplots(figsize=(10, 8))
    ax1.scatter(score, training_time, s=60)
    ax1.set(
        title="Score-training time trade-off",
        yscale="log",
        xlabel="test accuracy",
        ylabel="training time (s)",
    )
    fig, ax2 = plt.subplots(figsize=(10, 8))
    ax2.scatter(score, test_time, s=60)
    ax2.set(
        title="Score-test time trade-off",
        yscale="log",
        xlabel="test accuracy",
        ylabel="test time (s)",
    )

    for i, txt in enumerate(clf_names):
        ax1.annotate(txt, (score[i], training_time[i]))
        ax2.annotate(txt, (score[i], test_time[i]))

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_005.png
         :alt: Score-training time trade-off
         :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_006.png
         :alt: Score-test time trade-off
         :srcset: /auto_examples/text/images/sphx_glr_plot_document_classification_20newsgroups_006.png
         :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 371-376

The naive Bayes model has the best trade-off between score and training/testing time, while Random Forest is slow to train, expensive to predict and has comparatively poor accuracy. This is expected: for high-dimensional prediction problems, linear models are often better suited, as most problems become linearly separable when the feature space has 10,000 dimensions or more.

The difference in training speed and accuracy of the linear models can be explained by the choice of the loss function they optimize and the kind of regularization they use. Be aware that some linear models with the same loss but a different solver or regularization configuration may yield different fitting times and test accuracy. We can observe on the second plot that, once trained, all linear models have approximately the same prediction speed, which is expected because they all implement the same prediction function.

KNeighborsClassifier has a relatively low accuracy and the highest testing time. The long prediction time is also expected: for each prediction the model has to compute the pairwise distances between the testing sample and each document in the training set, which is computationally expensive. Furthermore, the "curse of dimensionality" harms the ability of this model to yield competitive accuracy in the high-dimensional feature space of text classification problems.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.628 seconds)

.. _sphx_glr_download_auto_examples_text_plot_document_classification_20newsgroups.py:

.. only:: html

  ..
  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/text/plot_document_classification_20newsgroups.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_document_classification_20newsgroups.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_document_classification_20newsgroups.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_document_classification_20newsgroups.zip `

.. include:: plot_document_classification_20newsgroups.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_