跳到主要内容

使用嵌入进行回归

nbviewer

回归意味着预测一个数字,而不是其中的一个类别。我们将基于评论文本的嵌入来预测评分。我们将数据集分割为训练集和测试集,以便在未见数据上实际评

import pandas as pd
import numpy as np
from ast import literal_eval

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error

datafile_path = "data/fine_food_reviews_with_embeddings_1k.csv"

df = pd.read_csv(datafile_path)
df["embedding"] = df.embedding.apply(literal_eval).apply(np.array)

X_train, X_test, y_train, y_test = train_test_split(list(df.embedding.values), df.Score, test_size=0.2, random_state=42)

rfr = RandomForestRegressor(n_estimators=100)
rfr.fit(X_train, y_train)
preds = rfr.predict(X_test)

mse = mean_squared_error(y_test, preds)
mae = mean_absolute_error(y_test, preds)

print(f"text-embedding-3-small performance on 1k Amazon reviews: mse={mse:.2f}, mae={mae:.2f}")


text-embedding-3-small performance on 1k Amazon reviews: mse=0.65, mae=0.52
bmse = mean_squared_error(y_test, np.repeat(y_test.mean(), len(y_test)))
bmae = mean_absolute_error(y_test, np.repeat(y_test.mean(), len(y_test)))
print(
f"Dummy mean prediction performance on Amazon reviews: mse={bmse:.2f}, mae={bmae:.2f}"
)


Dummy mean prediction performance on Amazon reviews: mse=1.73, mae=1.03

我们可以看到,嵌入能够以每个评分预测的平均误差为0.53。这大致相当于完美预测一半的评论,另一半的预测会偏离一个星级。

你也可以训练一个分类器来预测标签,或者在现有的机器学习模型中使用嵌入来编码自由文本特征。