Note
Go to the end to download the full example code. or to run this example in your browser via Binder
使用稀疏特征对文本文档进行分类#
这是一个示例,展示了如何使用scikit-learn通过 词袋模型 对文档进行按主题分类。此示例使用 Tf-idf加权的文档-词项稀疏矩阵来编码特征,并展示了几种可以高效处理稀疏矩阵的分类器。
对于通过无监督学习方法进行文档分析,请参见示例脚本 使用k-means聚类文本文档 .
# 作者:scikit-learn 开发者
# SPDX-License-Identifier: BSD-3-Clause
加载和向量化20个新闻组文本数据集#
我们定义了一个函数来从 The 20 newsgroups text dataset 加载数据,该数据集包含大约 18,000 篇关于 20 个主题的新闻组帖子,分为两个子集:一个用于训练(或开发),另一个用于测试(或性能评估)。请注意,默认情况下,文本样本包含一些消息元数据,例如 'headers'
、 'footers'
(签名)和对其他帖子的 'quotes'
。因此, fetch_20newsgroups
函数接受一个名为 remove
的参数,尝试去除这些可能使分类问题“过于简单”的信息。这是通过使用既不完美也不标准的简单启发式方法来实现的,因此默认情况下是禁用的。
from time import time
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
categories = [
"alt.atheism",
"talk.religion.misc",
"comp.graphics",
"sci.space",
]
def size_mb(docs):
return sum(len(s.encode("utf-8")) for s in docs) / 1e6
def load_dataset(verbose=False, remove=()):
"""加载并向量化20个新闻组数据集。"""
data_train = fetch_20newsgroups(
subset="train",
categories=categories,
shuffle=True,
random_state=42,
remove=remove,
)
data_test = fetch_20newsgroups(
subset="test",
categories=categories,
shuffle=True,
random_state=42,
remove=remove,
)
# `target_names` 中标签的顺序可能与 `categories` 不同
target_names = data_train.target_names
# 将目标分为训练集和测试集
y_train, y_test = data_train.target, data_test.target
# 使用稀疏向量化器从训练数据中提取特征
t0 = time()
vectorizer = TfidfVectorizer(
sublinear_tf=True, max_df=0.5, min_df=5, stop_words="english"
)
X_train = vectorizer.fit_transform(data_train.data)
duration_train = time() - t0
# 使用相同的向量化器从测试数据中提取特征
t0 = time()
X_test = vectorizer.transform(data_test.data)
duration_test = time() - t0
feature_names = vectorizer.get_feature_names_out()
if verbose:
# 计算已加载数据的大小
data_train_size_mb = size_mb(data_train.data)
data_test_size_mb = size_mb(data_test.data)
print(
f"{len(data_train.data)} documents - "
f"{data_train_size_mb:.2f}MB (training set)"
)
print(f"{len(data_test.data)} documents - {data_test_size_mb:.2f}MB (test set)")
print(f"{len(target_names)} categories")
print(
f"vectorize training done in {duration_train:.3f}s "
f"at {data_train_size_mb / duration_train:.3f}MB/s"
)
print(f"n_samples: {X_train.shape[0]}, n_features: {X_train.shape[1]}")
print(
f"vectorize testing done in {duration_test:.3f}s "
f"at {data_test_size_mb / duration_test:.3f}MB/s"
)
print(f"n_samples: {X_test.shape[0]}, n_features: {X_test.shape[1]}")
return X_train, X_test, y_train, y_test, feature_names, target_names
袋装词文档分类器的分析#
我们现在将训练一个分类器两次,一次是在包含元数据的文本样本上,另一次是在去除元数据后。对于这两种情况,我们将使用混淆矩阵分析测试集上的分类错误,并检查定义训练模型的分类函数的系数。
没有元数据剥离的模型#
我们首先使用自定义函数 load_dataset
来加载未去除元数据的数据。
X_train, X_test, y_train, y_test, feature_names, target_names = load_dataset(
verbose=True
)
2034 documents - 3.98MB (training set)
1353 documents - 2.87MB (test set)
4 categories
vectorize training done in 0.203s at 19.580MB/s
n_samples: 2034, n_features: 7831
vectorize testing done in 0.143s at 20.087MB/s
n_samples: 1353, n_features: 7831
我们的第一个模型是 RidgeClassifier
类的一个实例。这是一个线性分类模型,使用对每个可能类别进行 {-1, 1} 编码的目标的均方误差。与 LogisticRegression
相反,RidgeClassifier
不提供概率预测(没有 predict_proba
方法),但通常训练速度更快。
from sklearn.linear_model import RidgeClassifier
clf = RidgeClassifier(tol=1e-2, solver="sparse_cg")
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
我们绘制此分类器的混淆矩阵,以查找分类错误中是否存在模式。
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
fig, ax = plt.subplots(figsize=(10, 5))
ConfusionMatrixDisplay.from_predictions(y_test, pred, ax=ax)
ax.xaxis.set_ticklabels(target_names)
ax.yaxis.set_ticklabels(target_names)
_ = ax.set_title(
f"Confusion Matrix for {clf.__class__.__name__}\non the original documents"
)
混淆矩阵突显出, alt.atheism
类的文档经常与 talk.religion.misc
类的文档混淆,反之亦然,这是预料之中的,因为这些主题在语义上是相关的。
我们还观察到, sci.space
类的一些文档可能会被错误分类为 comp.graphics
,而反过来的情况则少得多。需要手动检查这些错误分类的文档,以便对这种不对称性有一些了解。可能是因为空间主题的词汇比计算机图形的词汇更为具体。
我们可以通过查看具有最高平均特征效应的词语,来更深入地了解这个分类器是如何做出决策的:
import numpy as np
import pandas as pd
def plot_feature_effects():
# 按出现频率加权的学习系数
average_feature_effects = clf.coef_ * np.asarray(X_train.mean(axis=0)).ravel()
for i, label in enumerate(target_names):
top5 = np.argsort(average_feature_effects[i])[-5 :][::-1]
if i == 0:
top = pd.DataFrame(feature_names[top5], columns=[label])
top_indices = top5
else:
top[label] = feature_names[top5]
top_indices = np.concatenate((top_indices, top5), axis=None)
top_indices = np.unique(top_indices)
predictive_words = feature_names[top_indices]
# 绘制特征效果
bar_size = 0.25
padding = 0.75
y_locs = np.arange(len(top_indices)) * (4 * bar_size + padding)
fig, ax = plt.subplots(figsize=(10, 8))
for i, label in enumerate(target_names):
ax.barh(
y_locs + (i - 2) * bar_size,
average_feature_effects[i, top_indices],
height=bar_size,
label=label,
)
ax.set(
yticks=y_locs,
yticklabels=predictive_words,
ylim=[
0 - 4 * bar_size,
len(top_indices) * (4 * bar_size + padding) - 4 * bar_size,
],
)
ax.legend(loc="lower right")
print("top 5 keywords per class:")
print(top)
return ax
_ = plot_feature_effects().set_title("Average feature effect on the original data")
top 5 keywords per class:
alt.atheism comp.graphics sci.space talk.religion.misc
0 keith graphics space christian
1 god university nasa com
2 atheists thanks orbit god
3 people does moon morality
4 caltech image access people
我们可以观察到,最具预测性的词语通常与某个单一类别强烈正相关,而与所有其他类别负相关。大多数正相关的词语都很容易解释。然而,一些词语如“god”和“people”同时与“talk.misc.religion”和“alt.atheism”正相关,因为这两个类别预期会共享一些常见词汇。然而,也有一些词语如“christian”和“morality”仅与“talk.misc.religion”正相关。此外,在这个版本的数据集中,由于数据集中的某些元数据污染,例如讨论中前几封电子邮件的发件人地址,“caltech”这个词成为了无神论的顶级预测特征之一,如下所示:
data_train = fetch_20newsgroups(
subset="train", categories=categories, shuffle=True, random_state=42
)
for doc in data_train.data:
if "caltech" in doc:
print(doc)
break
From: livesey@solntze.wpd.sgi.com (Jon Livesey)
Subject: Re: Morality? (was Re: <Political Atheists?)
Organization: sgi
Lines: 93
Distribution: world
NNTP-Posting-Host: solntze.wpd.sgi.com
In article <1qlettINN8oi@gap.caltech.edu>, keith@cco.caltech.edu (Keith Allan Schneider) writes:
|> livesey@solntze.wpd.sgi.com (Jon Livesey) writes:
|>
|> >>>Explain to me
|> >>>how instinctive acts can be moral acts, and I am happy to listen.
|> >>For example, if it were instinctive not to murder...
|> >
|> >Then not murdering would have no moral significance, since there
|> >would be nothing voluntary about it.
|>
|> See, there you go again, saying that a moral act is only significant
|> if it is "voluntary." Why do you think this?
If you force me to do something, am I morally responsible for it?
|>
|> And anyway, humans have the ability to disregard some of their instincts.
Well, make up your mind. Is it to be "instinctive not to murder"
or not?
|>
|> >>So, only intelligent beings can be moral, even if the bahavior of other
|> >>beings mimics theirs?
|> >
|> >You are starting to get the point. Mimicry is not necessarily the
|> >same as the action being imitated. A Parrot saying "Pretty Polly"
|> >isn't necessarily commenting on the pulchritude of Polly.
|>
|> You are attaching too many things to the term "moral," I think.
|> Let's try this: is it "good" that animals of the same species
|> don't kill each other. Or, do you think this is right?
It's not even correct. Animals of the same species do kill
one another.
|>
|> Or do you think that animals are machines, and that nothing they do
|> is either right nor wrong?
Sigh. I wonder how many times we have been round this loop.
I think that instinctive bahaviour has no moral significance.
I am quite prepared to believe that higher animals, such as
primates, have the beginnings of a moral sense, since they seem
to exhibit self-awareness.
|>
|>
|> >>Animals of the same species could kill each other arbitarily, but
|> >>they don't.
|> >
|> >They do. I and other posters have given you many examples of exactly
|> >this, but you seem to have a very short memory.
|>
|> Those weren't arbitrary killings. They were slayings related to some
|> sort of mating ritual or whatnot.
So what? Are you trying to say that some killing in animals
has a moral significance and some does not? Is this your
natural morality>
|>
|> >>Are you trying to say that this isn't an act of morality because
|> >>most animals aren't intelligent enough to think like we do?
|> >
|> >I'm saying:
|> > "There must be the possibility that the organism - it's not
|> > just people we are talking about - can consider alternatives."
|> >
|> >It's right there in the posting you are replying to.
|>
|> Yes it was, but I still don't understand your distinctions. What
|> do you mean by "consider?" Can a small child be moral? How about
|> a gorilla? A dolphin? A platypus? Where is the line drawn? Does
|> the being need to be self aware?
Are you blind? What do you think that this sentence means?
"There must be the possibility that the organism - it's not
just people we are talking about - can consider alternatives."
What would that imply?
|>
|> What *do* you call the mechanism which seems to prevent animals of
|> the same species from (arbitrarily) killing each other? Don't
|> you find the fact that they don't at all significant?
I find the fact that they do to be significant.
jon.
这些标头、签名页脚(以及来自先前消息的引用元数据)可以被视为通过识别注册成员而人为地揭示新闻组的辅助信息,而我们更希望文本分类器仅从每个文本文档的“主要内容”中学习,而不是依赖泄露的作者身份。
带有元数据剥离的模型
scikit-learn 中 20 个新闻组数据集加载器的 remove
选项允许通过启发式方法尝试过滤掉一些不需要的元数据,这些元数据会使分类问题变得人为地更简单。请注意,这种文本内容的过滤远非完美。
让我们尝试利用这个选项来训练一个文本分类器,使其在做出决策时不过多依赖此类元数据:
(
X_train,
X_test,
y_train,
y_test,
feature_names,
target_names,
) = load_dataset(remove=("headers", "footers", "quotes"))
clf = RidgeClassifier(tol=1e-2, solver="sparse_cg")
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
fig, ax = plt.subplots(figsize=(10, 5))
ConfusionMatrixDisplay.from_predictions(y_test, pred, ax=ax)
ax.xaxis.set_ticklabels(target_names)
ax.yaxis.set_ticklabels(target_names)
_ = ax.set_title(
f"Confusion Matrix for {clf.__class__.__name__}\non filtered documents"
)
通过查看混淆矩阵,更明显地看出使用元数据训练的模型的得分过于乐观。没有元数据的分类问题虽然准确性较低,但更能代表预期的文本分类问题。
_ = plot_feature_effects().set_title("Average feature effects on filtered documents")
top 5 keywords per class:
alt.atheism comp.graphics sci.space talk.religion.misc
0 don graphics space god
1 people file like christian
2 say thanks nasa jesus
3 religion image orbit christians
4 post does launch wrong
在下一节中,我们将保留没有元数据的数据集,以比较几种分类器。
基准测试分类器#
Scikit-learn 提供了许多不同种类的分类算法。在本节中,我们将针对同一个文本分类问题训练这些分类器,并测量它们的泛化性能(测试集上的准确率)和计算性能(速度),包括训练时间和测试时间。为此,我们定义了以下基准测试工具:
from sklearn import metrics
from sklearn.utils.extmath import density
def benchmark(clf, custom_name=False):
print("_" * 80)
print("Training: ")
print(clf)
t0 = time()
clf.fit(X_train, y_train)
train_time = time() - t0
print(f"train time: {train_time:.3}s")
t0 = time()
pred = clf.predict(X_test)
test_time = time() - t0
print(f"test time: {test_time:.3}s")
score = metrics.accuracy_score(y_test, pred)
print(f"accuracy: {score:.3}")
if hasattr(clf, "coef_"):
print(f"dimensionality: {clf.coef_.shape[1]}")
print(f"density: {density(clf.coef_)}")
print()
print()
if custom_name:
clf_descr = str(custom_name)
else:
clf_descr = clf.__class__.__name__
return clf_descr, score, train_time, test_time
我们现在使用8种不同的分类模型对数据集进行训练和测试,并获取每个模型的性能结果。本研究的目的是突出在这种多类别文本分类问题中,不同类型分类器的计算/准确性权衡。
请注意,最重要的超参数值是通过网格搜索过程调整的,为了简化起见,这个过程没有在本笔记本中展示。请参见示例脚本 文本特征提取和评估的示例管道 noqa: E501 以了解如何进行此类调整的演示。
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import ComplementNB
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.svm import LinearSVC
results = []
for clf, name in (
(LogisticRegression(C=5, max_iter=1000), "Logistic Regression"),
(RidgeClassifier(alpha=1.0, solver="sparse_cg"), "Ridge Classifier"),
(KNeighborsClassifier(n_neighbors=100), "kNN"),
(RandomForestClassifier(), "Random Forest"),
# L2惩罚线性支持向量机
(LinearSVC(C=0.1, dual=False, max_iter=1000), "Linear SVC"),
# L2惩罚线性随机梯度下降 (SGD)
(
SGDClassifier(
loss="log_loss", alpha=1e-4, n_iter_no_change=3, early_stopping=True
),
"log-loss SGD",
),
# NearestCentroid(又名Rocchio分类器)
(NearestCentroid(), "NearestCentroid"),
# 稀疏朴素贝叶斯分类器
(ComplementNB(alpha=0.1), "Complement naive Bayes"),
):
print("=" * 80)
print(name)
results.append(benchmark(clf, name))
================================================================================
Logistic Regression
________________________________________________________________________________
Training:
LogisticRegression(C=5, max_iter=1000)
train time: 3.84s
test time: 0.00113s
accuracy: 0.772
dimensionality: 5316
density: 1.0
================================================================================
Ridge Classifier
________________________________________________________________________________
Training:
RidgeClassifier(solver='sparse_cg')
train time: 0.0203s
test time: 0.000765s
accuracy: 0.76
dimensionality: 5316
density: 1.0
================================================================================
kNN
________________________________________________________________________________
Training:
KNeighborsClassifier(n_neighbors=100)
train time: 0.000469s
test time: 0.0809s
accuracy: 0.752
================================================================================
Random Forest
________________________________________________________________________________
Training:
RandomForestClassifier()
train time: 2.97s
test time: 0.0368s
accuracy: 0.702
================================================================================
Linear SVC
________________________________________________________________________________
Training:
LinearSVC(C=0.1, dual=False)
train time: 0.0197s
test time: 0.000665s
accuracy: 0.752
dimensionality: 5316
density: 1.0
================================================================================
log-loss SGD
________________________________________________________________________________
Training:
SGDClassifier(early_stopping=True, loss='log_loss', n_iter_no_change=3)
train time: 0.0241s
test time: 0.000684s
accuracy: 0.766
dimensionality: 5316
density: 1.0
================================================================================
NearestCentroid
________________________________________________________________________________
Training:
NearestCentroid()
train time: 0.00173s
test time: 0.00173s
accuracy: 0.748
================================================================================
Complement naive Bayes
________________________________________________________________________________
Training:
ComplementNB(alpha=0.1)
train time: 0.00109s
test time: 0.000638s
accuracy: 0.779
绘制每个分类器的准确率、训练时间和测试时间#
散点图显示了每个分类器的测试准确率与训练和测试时间之间的权衡。
indices = np.arange(len(results))
results = [[x[i] for x in results] for i in range(4)]
clf_names, score, training_time, test_time = results
training_time = np.array(training_time)
test_time = np.array(test_time)
fig, ax1 = plt.subplots(figsize=(10, 8))
ax1.scatter(score, training_time, s=60)
ax1.set(
title="Score-training time trade-off",
yscale="log",
xlabel="test accuracy",
ylabel="training time (s)",
)
fig, ax2 = plt.subplots(figsize=(10, 8))
ax2.scatter(score, test_time, s=60)
ax2.set(
title="Score-test time trade-off",
yscale="log",
xlabel="test accuracy",
ylabel="test time (s)",
)
for i, txt in enumerate(clf_names):
ax1.annotate(txt, (score[i], training_time[i]))
ax2.annotate(txt, (score[i], test_time[i]))
朴素贝叶斯模型在得分和训练/测试时间之间具有最佳的权衡,而随机森林训练缓慢,预测代价高且准确性相对较差。这是预料之中的:对于高维预测问题,线性模型通常更适合,因为当特征空间有10000维或更多时,大多数问题变得线性可分。
线性模型在训练速度和准确性上的差异可以通过它们优化的损失函数的选择以及使用的正则化类型来解释。请注意,某些具有相同损失但使用不同求解器或正则化配置的线性模型可能会产生不同的拟合时间和测试准确性。我们可以在第二个图上观察到,一旦训练完成,所有线性模型的预测速度大致相同,这是预期的,因为它们都实现了相同的预测函数。
KNeighborsClassifier 的准确率相对较低,并且测试时间最长。预测时间长也是预料之中的:对于每个预测,模型必须计算测试样本与训练集中每个文档之间的成对距离,这在计算上是昂贵的。此外,“维度灾难”会损害该模型在文本分类问题的高维特征空间中产生具有竞争力的准确性的能力。
Total running time of the script: (0 minutes 9.628 seconds)
Related examples