.. GENERATED FROM PYTHON SOURCE LINES 15-31

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    import numpy as np

    from sklearn.compose import ColumnTransformer
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.decomposition import PCA
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics import classification_report
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer
    from sklearn.svm import LinearSVC

.. GENERATED FROM PYTHON SOURCE LINES 32-36

20 newsgroups dataset
---------------------

We will use the :ref:`20 newsgroups dataset <20newsgroups_dataset>`, which
comprises posts from newsgroups on 20 topics. The dataset is split into train
and test subsets based on whether messages were posted before or after a
specific date. To speed up running time, we will only use posts from 2
categories.

.. GENERATED FROM PYTHON SOURCE LINES 36-53

.. code-block:: Python


    categories = ["sci.med", "sci.space"]
    X_train, y_train = fetch_20newsgroups(
        random_state=1,
        subset="train",
        categories=categories,
        remove=("footers", "quotes"),
        return_X_y=True,
    )
    X_test, y_test = fetch_20newsgroups(
        random_state=1,
        subset="test",
        categories=categories,
        remove=("footers", "quotes"),
        return_X_y=True,
    )

.. GENERATED FROM PYTHON SOURCE LINES 54-55

Each sample is a single string holding the post's header lines, which carry
meta information such as the subject, followed by the body of the post.

.. GENERATED FROM PYTHON SOURCE LINES 55-58

.. code-block:: Python


    print(X_train[0])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    From: mccall@mksol.dseg.ti.com (fred j mccall 575-3539)
    Subject: Re: Metric vs English
    Article-I.D.: mksol.1993Apr6.131900.8407
    Organization: Texas Instruments Inc
    Lines: 31

    American, perhaps, but nothing military about it. I learned (mostly)
    slugs when we talked English units in high school physics and while
    the teacher was an ex-Navy fighter jock the book certainly wasn't
    produced by the military.

    [Poundals were just too flinking small and made the math come out
    funny; sort of the same reason proponents of SI give for using that.]

    -- 
    "Insisting on perfect safety is for people who don't have the balls to live
     in the real world."   -- Mary Shafer, NASA Ames Dryden

.. GENERATED FROM PYTHON SOURCE LINES 59-63

Creating transformers
---------------------

First, we want a transformer that extracts the subject and body of each post.
Since this is a stateless transformation (it does not need any information
learned from the training data), we can define a function that performs the
transformation and then wrap it with
:class:`~sklearn.preprocessing.FunctionTransformer` to create a scikit-learn
transformer. The function splits each post at its first blank line: the text
before it holds the headers, from which the ``Subject:`` line is taken, and
the text after it is the body.

.. GENERATED FROM PYTHON SOURCE LINES 63-89

.. code-block:: Python



    def subject_body_extractor(posts):
        # Construct an object dtype array with two columns:
        # first column = 'subject' and second column = 'body'
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            # temporary variable `_` stores '\n\n'
            headers, _, body = text.partition("\n\n")
            # store body text in second column
            features[i, 1] = body

            prefix = "Subject:"
            sub = ""
            # save text after 'Subject:' in first column
            for line in headers.split("\n"):
                if line.startswith(prefix):
                    sub = line[len(prefix) :]
                    break
            features[i, 0] = sub

        return features


    subject_body_transformer = FunctionTransformer(subject_body_extractor)

.. GENERATED FROM PYTHON SOURCE LINES 90-91

We will also create a transformer that extracts, for each post, the length of
the text and the number of sentences (approximated by counting periods). It
returns one dictionary per post, which
:class:`~sklearn.feature_extraction.DictVectorizer` later turns into a feature
matrix.

.. GENERATED FROM PYTHON SOURCE LINES 91-99

.. code-block:: Python



    def text_stats(posts):
        return [{"length": len(text), "num_sentences": text.count(".")} for text in posts]


    text_stats_transformer = FunctionTransformer(text_stats)

.. GENERATED FROM PYTHON SOURCE LINES 100-105

Classification pipeline
-----------------------

The pipeline below extracts the subject and body from each post using
``subject_body_transformer``, producing an (n_samples, 2) array. This array is
then passed to ``ColumnTransformer``, which computes TF-IDF weighted
bag-of-words features for the subject and for the body, as well as the text
length and number of sentences of the body. We combine these features with
weights and then train a classifier on the combined feature set.

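
To see what ``transformer_weights`` does in isolation, here is a minimal
sketch on a hypothetical two-post array rather than the newsgroups data. It
mirrors the structure of the ``"union"`` step: a text vectorizer on the
subject column (col 0) and ``FunctionTransformer`` feeding ``DictVectorizer``
on the body column (col 1). Two assumptions keep the output easy to read:
:class:`~sklearn.feature_extraction.text.CountVectorizer` stands in for
``TfidfVectorizer``, and ``sparse_threshold=0.0`` forces a dense result.
``ColumnTransformer`` multiplies each transformer's output by its weight and
stacks the blocks side by side.

.. code-block:: Python

    import numpy as np
    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer

    # Hypothetical toy input shaped like the output of subject_body_transformer:
    # column 0 = subject, column 1 = body
    posts = np.array(
        [
            ["launch date", "The rocket launched. It worked."],
            ["flu symptoms", "Rest and fluids help."],
        ],
        dtype=object,
    )


    def text_stats(bodies):
        # Copy of text_stats above: text length and number of periods per post
        return [{"length": len(t), "num_sentences": t.count(".")} for t in bodies]


    body_stats = Pipeline(
        [
            ("stats", FunctionTransformer(text_stats)),  # returns a list of dicts
            ("vect", DictVectorizer(sparse=False)),  # list of dicts -> dense matrix
        ]
    )

    union = ColumnTransformer(
        [
            # CountVectorizer is a stand-in for TfidfVectorizer (assumption)
            ("subject", CountVectorizer(), 0),
            ("body_stats", body_stats, 1),
        ],
        transformer_weights={"subject": 0.8, "body_stats": 1.0},
        sparse_threshold=0.0,  # always return a dense array for inspection
    )

    features = union.fit_transform(posts)
    print(features)

The first four columns are the word counts for ``date``, ``flu``, ``launch``
and ``symptoms``, each scaled by 0.8; the last two are ``length`` and
``num_sentences``, left unscaled by their weight of 1.0.
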

.. GENERATED FROM PYTHON SOURCE LINES 106-161

.. code-block:: Python


    pipeline = Pipeline(
        [
            # Extract subject & body
            ("subjectbody", subject_body_transformer),
            # Use ColumnTransformer to combine the subject and body features
            (
                "union",
                ColumnTransformer(
                    [
                        # TF-IDF bag-of-words for subject (col 0)
                        ("subject", TfidfVectorizer(min_df=50), 0),
                        # TF-IDF bag-of-words with PCA decomposition for body (col 1)
                        (
                            "body_bow",
                            Pipeline(
                                [
                                    ("tfidf", TfidfVectorizer()),
                                    ("best", PCA(n_components=50, svd_solver="arpack")),
                                ]
                            ),
                            1,
                        ),
                        # Pipeline for pulling text stats from the post's body
                        (
                            "body_stats",
                            Pipeline(
                                [
                                    (
                                        "stats",
                                        text_stats_transformer,
                                    ),  # returns a list of dicts
                                    (
                                        "vect",
                                        DictVectorizer(),
                                    ),  # list of dicts -> feature matrix
                                ]
                            ),
                            1,
                        ),
                    ],
                    # weights applied to each ColumnTransformer output
                    transformer_weights={
                        "subject": 0.8,
                        "body_bow": 0.5,
                        "body_stats": 1.0,
                    },
                ),
            ),
            # Use a linear SVM classifier on the combined features
            ("svc", LinearSVC(dual=False)),
        ],
        verbose=True,
    )

.. GENERATED FROM PYTHON SOURCE LINES 162-163

Finally, we fit our pipeline on the training data and use it to predict topics
for ``X_test``. The performance metrics of our pipeline are then printed;
label 0 corresponds to ``sci.med`` and label 1 to ``sci.space``.

.. GENERATED FROM PYTHON SOURCE LINES 163-167

.. code-block:: Python


    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    print("Classification report:\n\n{}".format(classification_report(y_test, y_pred)))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [Pipeline] ....... (step 1 of 3) Processing subjectbody, total=   0.0s
    [Pipeline] ............. (step 2 of 3) Processing union, total=   0.5s
    [Pipeline] ............... (step 3 of 3) Processing svc, total=   0.0s
    Classification report:

                  precision    recall  f1-score   support

               0       0.84      0.88      0.86       396
               1       0.87      0.84      0.85       394

        accuracy                           0.86       790
       macro avg       0.86      0.86      0.86       790
    weighted avg       0.86      0.86      0.86       790
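
Each row of the report summarizes, for one class, the counts of true
positives (TP), false positives (FP) and false negatives (FN): precision is
TP / (TP + FP), recall is TP / (TP + FN), and the F1 score is their harmonic
mean. The ``macro avg`` row is the unweighted mean over classes, and the
``weighted avg`` row weights each class by its support. As a minimal sketch on
hypothetical toy labels (not the newsgroups predictions), these numbers can be
recomputed by hand and checked against
:func:`~sklearn.metrics.precision_recall_fscore_support`.

.. code-block:: Python

    import numpy as np
    from sklearn.metrics import precision_recall_fscore_support

    # Hypothetical toy labels, not taken from the example above
    y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    y_pred = np.array([0, 0, 0, 1, 1, 1, 1, 1, 0, 0])


    def per_class_scores(y_true, y_pred, label):
        # Count true positives, false positives and false negatives for one class
        tp = np.sum((y_pred == label) & (y_true == label))
        fp = np.sum((y_pred == label) & (y_true != label))
        fn = np.sum((y_pred != label) & (y_true == label))
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
        support = np.sum(y_true == label)  # number of true samples of the class
        return precision, recall, f1, support


    # One row per class: precision, recall, f1, support
    manual = np.array([per_class_scores(y_true, y_pred, c) for c in (0, 1)])
    precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)

    # "macro avg" is the unweighted mean over classes,
    # "weighted avg" weights each class by its support
    macro_f1 = f1.mean()
    weighted_f1 = np.average(f1, weights=support)
    accuracy = np.mean(y_true == y_pred)

With these labels, class 0 has precision 3/5 = 0.6 and recall 3/4 = 0.75,
and the overall accuracy is 7/10 = 0.7.
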