从头开始的结构化数据分类
作者: fchollet
创建日期: 2020/06/09
最后修改: 2020/06/09
描述: 包含数值和分类特征的结构化数据的二元分类。
本示例演示了如何从原始CSV文件开始进行结构化数据分类。我们的数据包括数值和分类特征。我们将使用Keras预处理层来规范化数值特征并向量化分类特征。
请注意,此示例应在TensorFlow 2.5或更高版本中运行。
我们的数据集由克利夫兰诊所心脏病基金会提供。它是一个包含303行的CSV文件。每行包含有关患者的信息(一个样本),每列描述患者的一个属性(一个特征)。我们使用这些特征来预测患者是否有心脏病(二元分类)。
以下是每个特征的描述:
列 | 描述 | 特征类型 |
---|---|---|
Age | 年龄(以年为单位) | 数值 |
Sex | (1 = 男性; 0 = 女性) | 分类 |
CP | 胸痛类型(0, 1, 2, 3, 4) | 分类 |
Trestbpd | 休息时血压(以毫米汞柱为单位) | 数值 |
Chol | 血清胆固醇(毫克/分升) | 数值 |
FBS | 120毫克/分升的空腹血糖(1 = 真; 0 = 假) | 分类 |
RestECG | 休息心电图结果(0, 1, 2) | 分类 |
Thalach | 最大心率 | 数值 |
Exang | 运动诱发的心绞痛(1 = 是; 0 = 否) | 分类 |
Oldpeak | 运动相对于休息的ST抑郁 | 数值 |
Slope | 峰值运动ST段的斜率 | 数值 |
CA | 通过荧光检查染色的主要血管数量(0-3) | 数值与分类兼有 |
Thal | 3 = 正常; 6 = 固定缺陷; 7 = 可逆缺陷 | 分类 |
Target | 心脏病诊断(1 = 真; 0 = 假) | 目标 |
import os
# TensorFlow 是唯一支持字符串输入的后端。
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import pandas as pd
import keras
from keras import layers
让我们下载数据并将其加载到 Pandas 数据框中:
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
该数据集包括 303 个样本,每个样本有 14 列(13 个特征,加上目标标签):
dataframe.shape
(303, 14)
这是一些样本的预览:
dataframe.head()
age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
最后一列“目标”表明患者是否患有心脏病(1)或未患有(0)。
让我们将数据拆分为训练集和验证集:
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)
print(
f"使用 {len(train_dataframe)} 个样本进行训练 "
f"和 {len(val_dataframe)} 个样本进行验证"
)
使用 242 个样本进行训练和 61 个样本进行验证
让我们为每个数据框生成 tf.data.Dataset
对象:
def dataframe_to_dataset(dataframe):
dataframe = dataframe.copy()
labels = dataframe.pop("target")
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
ds = ds.shuffle(buffer_size=len(dataframe))
return ds
train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
每个 Dataset
返回一个元组 (input, target)
,其中 input
是特征的字典,target
是值 0
或 1
:
for x, y in train_ds.take(1):
print("输入:", x)
print("目标:", y)
输入: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=64>, 'sex': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'cp': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=128>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=263>, 'fbs': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'restecg': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=105>, 'exang': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.2>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'ca': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'thal': <tf.Tensor: shape=(), dtype=string, numpy=b'reversible'>}
目标: tf.Tensor(0, shape=(), dtype=int64)
让我们对数据集进行批处理:
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
以下特征是编码为整数的分类特征:
sex
cp
fbs
restecg
exang
ca
我们将使用 独热编码 编码这些特征。我们在这里有两个选择:
CategoryEncoding()
,需要知道输入值的范围,并且在范围外的输入上会出错。IntegerLookup()
,它将为输入构建查找表,并为未知输入值保留输出索引。对于这个例子,我们希望一个简单的解决方案,可以在推理时处理超出范围的输入,因此我们将使用 IntegerLookup()
。
我们还有一个编码为字符串的分类特征:thal
。我们将创建一个所有可能特征的索引,并使用 StringLookup()
层进行编码。
最后,以下特征是连续的数值特征:
age
trestbps
chol
thalach
oldpeak
slope
对于每个这些特征,我们将使用 Normalization()
层,以确保每个特征的均值为 0,标准差为 1。
下面,我们定义 2 个工具函数来执行操作:
encode_numerical_feature
用于对数值特征应用逐特征归一化。encode_categorical_feature
用于独热编码字符串或整数分类特征。def encode_numerical_feature(feature, name, dataset):
# 为我们的特征创建一个归一化层
normalizer = layers.Normalization()
# 准备一个仅提供我们特征的数据集
feature_ds = dataset.map(lambda x, y: x[name])
feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
# 学习数据的统计信息
normalizer.adapt(feature_ds)
# 归一化输入特征
encoded_feature = normalizer(feature)
return encoded_feature
def encode_categorical_feature(feature, name, dataset, is_string):
lookup_class = layers.StringLookup if is_string else layers.IntegerLookup
# 创建一个查找层,将字符串转换为整数索引
lookup = lookup_class(output_mode="binary")
# 准备一个仅提供我们特征的数据集
feature_ds = dataset.map(lambda x, y: x[name])
feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
# 学习可能的字符串值并为它们分配固定的整数索引
lookup.adapt(feature_ds)
# 将字符串输入转化为整数索引
encoded_feature = lookup(feature)
return encoded_feature
完成后,我们可以创建我们端到端的模型:
# 分类特征编码为整数
sex = keras.Input(shape=(1,), name="sex", dtype="int64")
cp = keras.Input(shape=(1,), name="cp", dtype="int64")
fbs = keras.Input(shape=(1,), name="fbs", dtype="int64")
restecg = keras.Input(shape=(1,), name="restecg", dtype="int64")
exang = keras.Input(shape=(1,), name="exang", dtype="int64")
ca = keras.Input(shape=(1,), name="ca", dtype="int64")
# 分类特征编码为字符串
thal = keras.Input(shape=(1,), name="thal", dtype="string")
# 数值特征
age = keras.Input(shape=(1,), name="age")
trestbps = keras.Input(shape=(1,), name="trestbps")
chol = keras.Input(shape=(1,), name="chol")
thalach = keras.Input(shape=(1,), name="thalach")
oldpeak = keras.Input(shape=(1,), name="oldpeak")
slope = keras.Input(shape=(1,), name="slope")
all_inputs = [
sex,
cp,
fbs,
restecg,
exang,
ca,
thal,
age,
trestbps,
chol,
thalach,
oldpeak,
slope,
]
# 整数分类特征
sex_encoded = encode_categorical_feature(sex, "sex", train_ds, False)
cp_encoded = encode_categorical_feature(cp, "cp", train_ds, False)
fbs_encoded = encode_categorical_feature(fbs, "fbs", train_ds, False)
restecg_encoded = encode_categorical_feature(restecg, "restecg", train_ds, False)
exang_encoded = encode_categorical_feature(exang, "exang", train_ds, False)
ca_encoded = encode_categorical_feature(ca, "ca", train_ds, False)
# 字符串分类特征
thal_encoded = encode_categorical_feature(thal, "thal", train_ds, True)
# 数值特征
age_encoded = encode_numerical_feature(age, "age", train_ds)
trestbps_encoded = encode_numerical_feature(trestbps, "trestbps", train_ds)
chol_encoded = encode_numerical_feature(chol, "chol", train_ds)
thalach_encoded = encode_numerical_feature(thalach, "thalach", train_ds)
oldpeak_encoded = encode_numerical_feature(oldpeak, "oldpeak", train_ds)
slope_encoded = encode_numerical_feature(slope, "slope", train_ds)
all_features = layers.concatenate(
[
sex_encoded,
cp_encoded,
fbs_encoded,
restecg_encoded,
exang_encoded,
slope_encoded,
ca_encoded,
thal_encoded,
age_encoded,
trestbps_encoded,
chol_encoded,
thalach_encoded,
oldpeak_encoded,
]
)
x = layers.Dense(32, activation="relu")(all_features)
x = layers.Dropout(0.5)(x)
output = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(all_inputs, output)
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
让我们可视化我们的连接图:
# `rankdir='LR'` 是为了使图形水平排列。
keras.utils.plot_model(model, show_shapes=True, rankdir="LR")
model.fit(train_ds, epochs=50, validation_data=val_ds)
Epoch 1/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 46ms/step - 准确率: 0.3932 - 损失: 0.8749 - 验证准确率: 0.3303 - 验证损失: 0.7814
Epoch 2/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - 准确率: 0.4262 - 损失: 0.8375 - 验证准确率: 0.4914 - 验证损失: 0.6980
Epoch 3/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.4835 - 损失: 0.7350 - 验证准确率: 0.6541 - 验证损失: 0.6320
Epoch 4/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.5932 - 损失: 0.6665 - 验证准确率: 0.7543 - 验证损失: 0.5743
Epoch 5/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.5861 - 损失: 0.6600 - 验证准确率: 0.7683 - 验证损失: 0.5360
Epoch 6/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.6489 - 损失: 0.6020 - 验证准确率: 0.7748 - 验证损失: 0.4998
Epoch 7/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.6880 - 损失: 0.5668 - 验证准确率: 0.7699 - 验证损失: 0.4800
Epoch 8/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7572 - 损失: 0.5009 - 验证准确率: 0.7559 - 验证损失: 0.4573
Epoch 9/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7492 - 损失: 0.5192 - 验证准确率: 0.8060 - 验证损失: 0.4414
Epoch 10/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.7212 - 损失: 0.4973 - 验证准确率: 0.8077 - 验证损失: 0.4259
Epoch 11/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7616 - 损失: 0.4704 - 验证准确率: 0.7904 - 验证损失: 0.4143
Epoch 12/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8374 - 损失: 0.4342 - 验证准确率: 0.7872 - 验证损失: 0.4061
Epoch 13/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7863 - 损失: 0.4630 - 验证准确率: 0.7888 - 验证损失: 0.3980
Epoch 14/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7742 - 损失: 0.4492 - 验证准确率: 0.7996 - 验证损失: 0.3998
Epoch 15/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8083 - 损失: 0.4280 - 验证准确率: 0.8060 - 验证损失: 0.3855
Epoch 16/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8058 - 损失: 0.4191 - 验证准确率: 0.8217 - 验证损失: 0.3819
Epoch 17/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8071 - 损失: 0.4111 - 验证准确率: 0.8389 - 验证损失: 0.3763
Epoch 18/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8533 - 损失: 0.3676 - 验证准确率: 0.8373 - 验证损失: 0.3792
Epoch 19/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8170 - 损失: 0.3850 - 验证准确率: 0.8357 - 验证损失: 0.3744
Epoch 20/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8207 - 损失: 0.3767 - 验证准确率: 0.8168 - 验证损失: 0.3759
Epoch 21/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8151 - 损失: 0.3596 - 验证准确率: 0.8217 - 验证损失: 0.3685
Epoch 22/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.7988 - 损失: 0.4087 - 验证准确率: 0.8184 - 验证损失: 0.3701
Epoch 23/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8180 - 损失: 0.3632 - 验证准确率: 0.8217 - 验证损失: 0.3614
Epoch 24/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8295 - 损失: 0.3504 - 验证准确率: 0.8200 - 验证损失: 0.3683
Epoch 25/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8386 - 损失: 0.3864 - 验证准确率: 0.8200 - 验证损失: 0.3655
Epoch 26/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8482 - 损失: 0.3345 - 验证准确率: 0.8044 - 验证损失: 0.3639
Epoch 27/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8340 - 损失: 0.3470 - 验证准确率: 0.8077 - 验证损失: 0.3616
Epoch 28/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8418 - 损失: 0.3684 - 验证准确率: 0.8060 - 验证损失: 0.3629
Epoch 29/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8309 - 损失: 0.3147 - 验证准确率: 0.8060 - 验证损失: 0.3637
Epoch 30/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8722 - 损失: 0.3151 - 验证准确率: 0.8044 - 验证损失: 0.3672
Epoch 31/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8746 - 损失: 0.3043 - 验证准确率: 0.8060 - 验证损失: 0.3637
Epoch 32/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8794 - 损失: 0.3245 - 验证准确率: 0.8200 - 验证损失: 0.3685
Epoch 33/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8644 - 损失: 0.3541 - 验证准确率: 0.8357 - 验证损失: 0.3714
Epoch 34/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8867 - 损失: 0.3007 - 验证准确率: 0.8373 - 验证损失: 0.3680
Epoch 35/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8737 - 损失: 0.3168 - 验证准确率: 0.8357 - 验证损失: 0.3695
Epoch 36/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8191 - 损失: 0.3298 - 验证准确率: 0.8357 - 验证损失: 0.3736
Epoch 37/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8613 - 损失: 0.3543 - 验证准确率: 0.8357 - 验证损失: 0.3745
Epoch 38/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8835 - 损失: 0.2835 - 验证准确率: 0.8357 - 验证损失: 0.3707
Epoch 39/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8784 - 损失: 0.2893 - 验证准确率: 0.8357 - 验证损失: 0.3716
Epoch 40/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8919 - 损失: 0.2587 - 验证准确率: 0.8168 - 验证损失: 0.3770
Epoch 41/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8882 - 损失: 0.2660 - 验证准确率: 0.8217 - 验证损失: 0.3674
Epoch 42/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8790 - 损失: 0.2931 - 验证准确率: 0.8200 - 验证损失: 0.3723
Epoch 43/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8851 - 损失: 0.2892 - 验证准确率: 0.8200 - 验证损失: 0.3733
Epoch 44/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8504 - 损失: 0.3189 - 验证准确率: 0.8200 - 验证损失: 0.3755
Epoch 45/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8610 - 损失: 0.3116 - 验证准确率: 0.8184 - 验证损失: 0.3788
Epoch 46/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8956 - 损失: 0.2544 - 验证准确率: 0.8184 - 验证损失: 0.3738
Epoch 47/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.9080 - 损失: 0.2895 - 验证准确率: 0.8217 - 验证损失: 0.3750
Epoch 48/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8706 - 损失: 0.2993 - 验证准确率: 0.8217 - 验证损失: 0.3757
Epoch 49/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - 准确率: 0.8724 - 损失: 0.2979 - 验证准确率: 0.8184 - 验证损失: 0.3781
Epoch 50/50
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - 准确率: 0.8609 - 损失: 0.2937 - 验证准确率: 0.8217 - 验证损失: 0.3791
<keras.src.callbacks.history.History at 0x7efc32e01780>
要对新样本进行预测,您只需调用 model.predict()
。您只需做两件事:
convert_to_tensor
sample = {
"age": 60,
"sex": 1,
"cp": 1,
"trestbps": 145,
"chol": 233,
"fbs": 1,
"restecg": 2,
"thalach": 150,
"exang": 0,
"oldpeak": 2.3,
"slope": 3,
"ca": 0,
"thal": "fixed",
}
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict)
print(
f"该患者患心脏病的概率为 {100 * predictions[0][0]:.1f} "
"百分比,由我们的模型评估得出。"
)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 252ms/step
该患者患心脏病的概率为27.6百分比,由我们的模型评估得出。