Column Transformer with Heterogeneous Data Sources
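Datasets can often contain components that require different feature extraction and processing pipelines. This scenario might occur when:
1. your dataset consists of heterogeneous data types (e.g. raster images and text captions),
2. your dataset is stored in a pandas.DataFrame and different columns require different processing pipelines.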
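This example demonstrates how to use ColumnTransformer on a dataset containing different types of features. The choice of features is not particularly helpful, but serves to illustrate the technique.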
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import PCA
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.svm import LinearSVC
20 newsgroups dataset
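We will use the 20 newsgroups dataset, which comprises posts from newsgroups on 20 topics. This dataset is split into train and test subsets based on messages posted before and after a specific date. To speed up running time, we will only use posts from 2 categories.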
categories = ["sci.med", "sci.space"]
X_train, y_train = fetch_20newsgroups(
    random_state=1,
    subset="train",
    categories=categories,
    remove=("footers", "quotes"),
    return_X_y=True,
)
X_test, y_test = fetch_20newsgroups(
    random_state=1,
    subset="test",
    categories=categories,
    remove=("footers", "quotes"),
    return_X_y=True,
)
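Each sample is the raw text of one post: header lines with meta information about the post (such as the sender and the subject), followed by the body of the news post.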
print(X_train[0])
From: mccall@mksol.dseg.ti.com (fred j mccall 575-3539)
Subject: Re: Metric vs English
Article-I.D.: mksol.1993Apr6.131900.8407
Organization: Texas Instruments Inc
Lines: 31
American, perhaps, but nothing military about it. I learned (mostly)
slugs when we talked English units in high school physics and while
the teacher was an ex-Navy fighter jock the book certainly wasn't
produced by the military.
[Poundals were just too flinking small and made the math come out
funny; sort of the same reason proponents of SI give for using that.]
--
"Insisting on perfect safety is for people who don't have the balls to live
in the real world." -- Mary Shafer, NASA Ames Dryden
Creating transformers
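First, we need a transformer that extracts the subject and body of each post. Since this is a stateless transformation (it does not require any state learned from the training data), we can define a function that performs the data transformation and then use FunctionTransformer to create a scikit-learn transformer.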
def subject_body_extractor(posts):
    # construct an object dtype array with two columns
    # first column = 'subject' and second column = 'body'
    features = np.empty(shape=(len(posts), 2), dtype=object)
    for i, text in enumerate(posts):
        # temporary variable `_` stores '\n\n'
        headers, _, body = text.partition("\n\n")
        # store body text in second column
        features[i, 1] = body
        prefix = "Subject:"
        sub = ""
        # save text after 'Subject:' in first column
        for line in headers.split("\n"):
            if line.startswith(prefix):
                sub = line[len(prefix) :]
                break
        features[i, 0] = sub
    return features
subject_body_transformer = FunctionTransformer(subject_body_extractor)
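The key step in subject_body_extractor is str.partition, which splits only at the first blank line, so blank lines later in the body stay in the body. A standalone sketch on a made-up post (not taken from the dataset) shows what ends up in each column:

```python
# Made-up post: header lines, one blank line, then a two-paragraph body
post = (
    "From: someone@example.com\n"
    "Subject: Re: Metric vs English\n"
    "\n"
    "First line of body.\n\nSecond paragraph."
)

# partition splits at the FIRST "\n\n" only and returns (before, sep, after)
headers, sep, body = post.partition("\n\n")

# Same subject lookup as in subject_body_extractor
subject = ""
for line in headers.split("\n"):
    if line.startswith("Subject:"):
        subject = line[len("Subject:") :]
        break

print(repr(subject))
print(repr(body))
```

The extracted subject keeps the leading space after the colon, which the default TfidfVectorizer tokenizer ignores. If a post has no blank line, partition returns the whole text as headers and an empty body; if no header line starts with "Subject:", the subject stays an empty string.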
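We will also create a transformer that extracts the length of the text and the number of sentences (approximated by counting periods).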
def text_stats(posts):
    return [{"length": len(text), "num_sentences": text.count(".")} for text in posts]
text_stats_transformer = FunctionTransformer(text_stats)
Classification pipeline
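The pipeline below extracts the subject and body from each post using subject_body_transformer, producing an (n_samples, 2) array. ColumnTransformer then uses this array to compute TF-IDF bag-of-words features for the subject, TF-IDF features reduced to 50 components with PCA for the body, and the text length and number of sentences of the body. We combine these feature sets with per-transformer weights, then train a classifier on the combined set of features.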
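transformer_weights multiplies each transformer's output by its weight before the outputs are concatenated. A toy sketch with identity transformers makes the effect visible; the data and the transformer names "a" and "b" are made up:

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import FunctionTransformer

X = np.array([[1.0, 10.0], [2.0, 20.0]])

# FunctionTransformer() with no function is the identity, so the only
# change to each column comes from its entry in transformer_weights
ct = ColumnTransformer(
    [
        ("a", FunctionTransformer(), [0]),
        ("b", FunctionTransformer(), [1]),
    ],
    transformer_weights={"a": 0.5, "b": 2.0},
)
X_weighted = ct.fit_transform(X)
print(X_weighted)
```

Column 0 is halved and column 1 doubled. In the classification pipeline, this means the subject TF-IDF block is scaled by 0.8 and the PCA-reduced body block by 0.5, while the text statistics keep weight 1.0.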
pipeline = Pipeline(
    [
        # Extract subject & body
        ("subjectbody", subject_body_transformer),
        # Use ColumnTransformer to combine the subject and body features
        (
            "union",
            ColumnTransformer(
                [
                    # bag-of-words for subject (col 0)
                    ("subject", TfidfVectorizer(min_df=50), 0),
                    # bag-of-words with decomposition for body (col 1)
                    (
                        "body_bow",
                        Pipeline(
                            [
                                ("tfidf", TfidfVectorizer()),
                                ("best", PCA(n_components=50, svd_solver="arpack")),
                            ]
                        ),
                        1,
                    ),
                    # Pipeline for pulling text stats from post's body
                    (
                        "body_stats",
                        Pipeline(
                            [
                                (
                                    "stats",
                                    text_stats_transformer,
                                ),  # returns a list of dicts
                                (
                                    "vect",
                                    DictVectorizer(),
                                ),  # list of dicts -> feature matrix
                            ]
                        ),
                        1,
                    ),
                ],
                # multiply each transformer's output by its weight before concatenation
                transformer_weights={
                    "subject": 0.8,
                    "body_bow": 0.5,
                    "body_stats": 1.0,
                },
            ),
        ),
        # Use a linear SVM classifier (LinearSVC) on the combined features
        ("svc", LinearSVC(dual=False)),
    ],
    verbose=True,
)
Finally, we fit our pipeline on the training data and use it to predict topics for X_test. The performance metrics of our pipeline are then printed.
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)
print("Classification report:\n\n{}".format(classification_report(y_test, y_pred)))
[Pipeline] ....... (step 1 of 3) Processing subjectbody, total= 0.0s
[Pipeline] ............. (step 2 of 3) Processing union, total= 0.5s
[Pipeline] ............... (step 3 of 3) Processing svc, total= 0.0s
Classification report:

              precision    recall  f1-score   support

           0       0.84      0.88      0.86       396
           1       0.87      0.84      0.85       394

    accuracy                           0.86       790
   macro avg       0.86      0.86      0.86       790
weighted avg       0.86      0.86      0.86       790
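For reading the report: for each class, precision is the fraction of predictions for that class that are correct, recall is the fraction of that class's true samples that are recovered, f1-score is their harmonic mean, and support is the number of true samples. "macro avg" is the unweighted mean over classes, and "weighted avg" weights each class by its support. A small sketch with made-up labels (not from this dataset) checks these definitions by hand:

```python
from sklearn.metrics import classification_report, precision_recall_fscore_support

# Made-up labels for four samples
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]

precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred)
# Class 0: 1 sample predicted as 0, and it is correct -> precision 1/1
#          2 true samples of class 0, 1 recovered    -> recall 1/2
# Class 1: 3 samples predicted as 1, 2 correct       -> precision 2/3
#          2 true samples of class 1, both recovered -> recall 2/2
print(precision, recall, f1, support)
print(classification_report(y_true, y_pred))
```

The f1-scores follow as 2PR/(P+R): 2/3 for class 0 and 0.8 for class 1.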
Total running time of the script: (0 minutes 1.857 seconds)