
Recommendation using embeddings and nearest neighbor search


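Recommendations are widespread across the web.

  • ‘Bought that item? Try these similar items.’
  • ‘Enjoyed that book? Try these similar titles.’
  • ‘Not the help page you were looking for? Try these similar pages.’

This notebook demonstrates how to use embeddings to find similar items to recommend. In particular, we use AG's corpus of news articles as our dataset.

Our model will answer the question: given an article, which other articles are most similar to it?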

1. Imports

import pandas as pd
import pickle

from utils.embeddings_utils import (
    get_embedding,
    distances_from_embeddings,
    tsne_components_from_embeddings,
    chart_from_components,
    indices_of_nearest_neighbors_from_distances,
)

EMBEDDING_MODEL = "text-embedding-3-small"


2. Load data

Next, let's load the AG news data and see what it looks like.

# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"
df = pd.read_csv(dataset_path)

n_examples = 5
df.head(n_examples)


title description label_int label
0 World Briefings BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M... 1 World
1 Nvidia Puts a Firewall on a Motherboard (PC Wo... PC World - Upcoming chip set will include buil... 4 Sci/Tech
2 Olympic joy in Greek, Chinese press Newspapers in Greece reflect a mixture of exhi... 2 Sports
3 U2 Can iPod with Pictures SAN JOSE, Calif. -- Apple Computer (Quote, Cha... 4 Sci/Tech
4 The Dream Factory Any product, any shape, any size -- manufactur... 4 Sci/Tech

Let's look at those same examples again, this time without the truncating ellipses.

# print the title, description, and label of each example
for idx, row in df.head(n_examples).iterrows():
    print("")
    print(f"Title: {row['title']}")
    print(f"Description: {row['description']}")
    print(f"Label: {row['label']}")



Title: World Briefings
Description: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the quot;alarming quot; growth of greenhouse gases.
Label: World

Title: Nvidia Puts a Firewall on a Motherboard (PC World)
Description: PC World - Upcoming chip set will include built-in security features for your PC.
Label: Sci/Tech

Title: Olympic joy in Greek, Chinese press
Description: Newspapers in Greece reflect a mixture of exhilaration that the Athens Olympics proved successful, and relief that they passed off without any major setback.
Label: Sports

Title: U2 Can iPod with Pictures
Description: SAN JOSE, Calif. -- Apple Computer (Quote, Chart) unveiled a batch of new iPods, iTunes software and promos designed to keep it atop the heap of digital music players.
Label: Sci/Tech

Title: The Dream Factory
Description: Any product, any shape, any size -- manufactured on your desktop! The future is the fabricator. By Bruce Sterling from Wired magazine.
Label: Sci/Tech

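3. Build cache to save embeddings

Before getting embeddings for these articles, let's set up a cache to save the embeddings we generate. In general, it's a good idea to save your embeddings so you can reuse them later. If you don't save them, you'll pay again every time you recompute them.

The cache is a dictionary that maps (text, model) tuples to an embedding, which is a list of floats. The cache is saved as a Python pickle file.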
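To make the cache format concrete, here is a minimal, self-contained sketch of a dict keyed by (text, model) tuples being round-tripped through pickle. The temporary file path and the three-number "embedding" are made up for illustration; real embeddings come from get_embedding and have far more dimensions.

```python
import os
import pickle
import tempfile

# Illustrative cache: (text, model) tuple -> embedding (a list of floats).
# The vector values here are fake, not a real embedding.
cache = {("hello world", "text-embedding-3-small"): [0.1, -0.2, 0.3]}

# Write the cache to a temporary pickle file (hypothetical path, not the notebook's).
path = os.path.join(tempfile.mkdtemp(), "example_cache.pkl")
with open(path, "wb") as f:
    pickle.dump(cache, f)

# Load it back: tuple keys and float lists survive the round trip unchanged.
with open(path, "rb") as f:
    loaded = pickle.load(f)

print(loaded[("hello world", "text-embedding-3-small")])
```

Including the model name in the key means that switching to a different embedding model never returns a stale vector computed by the old one. Keep in mind that pickle files should only be loaded from sources you trust.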

# establish a cache of embeddings to avoid recomputing
# the cache is a dict of (text, model) tuples -> embedding, saved as a pickle file

# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"

# load the cache if it exists, and save a copy to disk
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)

# define a function to retrieve embeddings from the cache if present, and otherwise request via the API
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache
) -> list:
    """Return embedding of given string, using a cache to avoid recomputing."""
    if (string, model) not in embedding_cache:
        embedding_cache[(string, model)] = get_embedding(string, model)
        # persist the updated cache so new embeddings survive restarts
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]


Let's check that it works by getting an embedding.

# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")

# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")



Example string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the quot;alarming quot; growth of greenhouse gases.

Example embedding: [0.0545826330780983, -0.00428084097802639, 0.04785159230232239, 0.01587914116680622, -0.03640881925821304, 0.0143799539655447, -0.014267769642174244, -0.015175441280007362, -0.002344391541555524, 0.011075624264776707]...

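4. Recommend similar articles based on embeddings

To find similar articles, let's follow a three-step plan:

  1. Get the embeddings of all the article descriptions.
  2. Calculate the distance between a source article and all other articles.
  3. Print out the other articles closest to the source article.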
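Stripped of caching and printing, those three steps amount to a few lines of NumPy. This is a minimal sketch on made-up 2D vectors, not the implementation inside utils.embeddings_utils. The function name recommend is hypothetical, and it excludes only the source index, whereas the notebook's function skips any string identical to the source.

```python
import numpy as np

def recommend(embeddings: np.ndarray, source_index: int, k: int) -> list[int]:
    """Return indices of the k items closest to embeddings[source_index] by cosine distance."""
    # Step 1 is assumed done: `embeddings` is an (n_items, dim) array.
    # Normalize rows so that a dot product equals cosine similarity.
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Step 2: cosine distance = 1 - cosine similarity, from the source to every item.
    distances = 1.0 - normed @ normed[source_index]
    # Step 3: sort by distance, drop the source itself, keep the k closest.
    order = np.argsort(distances)
    return [int(i) for i in order if i != source_index][:k]

toy = np.array([
    [1.0, 0.0],  # 0: source
    [0.9, 0.1],  # 1: nearly the same direction as the source
    [0.0, 1.0],  # 2: orthogonal to the source
    [0.7, 0.7],  # 3: in between
])
print(recommend(toy, source_index=0, k=2))  # [1, 3]
```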

def print_recommendations_from_strings(
    strings: list[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1,
    model=EMBEDDING_MODEL,
) -> list[int]:
    """Print out the k nearest neighbors of a given string."""
    # get embeddings for all strings
    embeddings = [embedding_from_string(string, model=model) for string in strings]

    # get the embedding of the source string
    query_embedding = embeddings[index_of_source_string]

    # get distances between the source embedding and other embeddings (function from utils.embeddings_utils.py)
    distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")

    # get indices of nearest neighbors (function from utils.embeddings_utils.py)
    indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)

    # print out the source string
    query_string = strings[index_of_source_string]
    print(f"Source string: {query_string}")
    # print out its k nearest neighbors
    k_counter = 0
    for i in indices_of_nearest_neighbors:
        # skip any strings that are identical matches to the starting string
        if query_string == strings[i]:
            continue
        # stop after printing out k articles
        if k_counter >= k_nearest_neighbors:
            break
        k_counter += 1

        # print out the similar strings and their distances
        print(f"\n--- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---")
        print(f"String: {strings[i]}")
        print(f"Distance: {distances[i]:0.3f}")

    return indices_of_nearest_neighbors
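The function above ranks neighbors by cosine distance. OpenAI's embeddings FAQ states that its embeddings are normalized to length 1. For unit vectors, ||u - v||² = 2 − 2·u·v, so cosine distance is exactly half the squared Euclidean distance, and both metrics produce the same neighbor ranking. The sketch below checks that identity on random unit vectors, which stand in for real embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins for embeddings, normalized to unit length.
vectors = rng.normal(size=(100, 64))
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
query = vectors[0]

cosine_distances = 1.0 - vectors @ query
squared_euclidean = np.sum((vectors - query) ** 2, axis=1)

# For unit vectors: ||u - v||^2 = 2 - 2 u.v = 2 * (1 - u.v)
print(np.allclose(cosine_distances, squared_euclidean / 2))

# So sorting by either metric yields the same nearest-neighbor order.
same_order = np.array_equal(np.argsort(cosine_distances), np.argsort(squared_euclidean))
print(same_order)
```

In practice this means a nearest-neighbor index built on Euclidean distance would return the same neighbors for these normalized embeddings.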


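5. Example recommendations

Let's look for articles similar to the first one, which was about Tony Blair.

article_descriptions = df["description"].tolist()

tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=0,  # let's look at articles similar to the first one about Tony Blair
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)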


Source string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases.

--- Recommendation #1 (nearest neighbor 1 of 5) ---
String: The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war.
Distance: 0.514

--- Recommendation #2 (nearest neighbor 2 of 5) ---
String: THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today.
Distance: 0.516

--- Recommendation #3 (nearest neighbor 3 of 5) ---
String: Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed
Distance: 0.546

--- Recommendation #4 (nearest neighbor 4 of 5) ---
String: Allowing dozens of casinos to be built in the UK would bring investment and thousands of jobs, Tony Blair says.
Distance: 0.568

--- Recommendation #5 (nearest neighbor 5 of 5) ---
String: AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair.
Distance: 0.579

Pretty good! All 5 recommendations explicitly mention Tony Blair, although none of them is about climate change, the subject of the source article.

Let's see how our recommender does on the second example article, about NVIDIA's new chipset with more security.

chipset_security_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=1,  # let's look at articles similar to the second one about a more secure chipset
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)


Source string: PC World - Upcoming chip set will include built-in security features for your PC.

--- Recommendation #1 (nearest neighbor 1 of 5) ---
String: PC World - Updated antivirus software for businesses adds intrusion prevention features.
Distance: 0.422

--- Recommendation #2 (nearest neighbor 2 of 5) ---
String: PC World - Symantec, McAfee hope raising virus-definition fees will move users to\ suites.
Distance: 0.518

--- Recommendation #3 (nearest neighbor 3 of 5) ---
String: originally offered on notebook PCs -- to its Opteron 32- and 64-bit x86 processors for server applications. The technology will help servers to run
Distance: 0.522

--- Recommendation #4 (nearest neighbor 4 of 5) ---
String: PC World - Send your video throughout your house--wirelessly--with new gateways and media adapters.
Distance: 0.532

--- Recommendation #5 (nearest neighbor 5 of 5) ---
String: Chips that help a computer's main microprocessors perform specific types of math problems are becoming a big business once again.\
Distance: 0.532

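The printed distances show that the 1st recommendation is much closer than all the others (0.422 vs. 0.518+). The 1st recommendation also looks very similar to the starting article: it's another PC World article about increasing computer security. Pretty good!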

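Appendix: Using embeddings in more sophisticated recommenders

One way to build a more sophisticated recommender system is to train a machine learning model that takes in tens or hundreds of signals, such as item popularity or user click data. Even in such a system, embeddings can be a very useful input signal, especially for items that have no user data yet (e.g., a brand-new product just added to the catalog that has no clicks).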
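As a rough illustration of embeddings acting as one signal among several, here is a sketch that scores candidates with a hand-weighted blend of embedding similarity and popularity. It is not a production design. The weights, the default popularity for items with no clicks, the candidate names, and the blend_scores helper are all hypothetical. A real system of the kind described above would typically learn how to combine signals with a trained model rather than use fixed weights.

```python
def blend_scores(
    similarity: dict[str, float],
    popularity: dict[str, float],
    w_similarity: float = 0.7,
    w_popularity: float = 0.3,
    default_popularity: float = 0.5,
) -> list[tuple[str, float]]:
    """Rank candidates by a weighted sum of embedding similarity and popularity.

    Both signals are assumed to be scaled to [0, 1]. Items with no popularity data
    (e.g. a new product with no clicks yet) get `default_popularity`, so they can
    still be ranked on their embedding similarity.
    """
    scored = [
        (item, w_similarity * sim + w_popularity * popularity.get(item, default_popularity))
        for item, sim in similarity.items()
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# Hypothetical candidates for one source article.
similarity = {"old_hit": 0.60, "old_niche": 0.80, "brand_new": 0.85}
popularity = {"old_hit": 0.90, "old_niche": 0.10}  # no click data for "brand_new"
for item, score in blend_scores(similarity, popularity):
    print(f"{item}: {score:.3f}")
```

Here "brand_new" ranks first even though it has no click history, because its embedding similarity carries it. This is the cold-start benefit the paragraph describes.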

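Appendix: Using embeddings to visualize similar articles

In this appendix, we show how embeddings can be used to visualize similar articles.

To get a sense of what our nearest neighbor recommender is doing, let's visualize the article embeddings. We can't plot all 1,536 dimensions of each text-embedding-3-small vector, but we can use techniques like t-SNE or PCA to compress the embeddings down to 2 or 3 dimensions, which we can chart.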
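PCA is the simpler of the two techniques to write down: center the vectors, then project them onto the directions of greatest variance. Here is a NumPy-only sketch on random stand-in data. The function name pca_components is hypothetical, and in practice you would likely reach for a library implementation such as scikit-learn's PCA.

```python
import numpy as np

def pca_components(embeddings: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project (n_items, dim) embeddings onto their top principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(200, 1536))  # random stand-in for 1,536-dim embeddings
points_2d = pca_components(fake_embeddings, n_components=2)
print(points_2d.shape)  # (200, 2)
```

Unlike t-SNE, PCA is a linear projection with no random initialization, so it gives the same result on every run. However, it tends to separate clusters less clearly.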

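Before visualizing the nearest neighbors, let's visualize all of the article descriptions using t-SNE. Note that t-SNE is stochastic: unless its random seed is fixed, results may vary from run to run.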
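If you want reproducible charts, you can fix the t-SNE random seed. Here is a sketch with scikit-learn's TSNE on small random data. It assumes scikit-learn is installed, and we don't know which settings the notebook's tsne_components_from_embeddings helper uses internally. method="exact" is chosen because it suits small inputs and avoids the approximate Barnes-Hut path.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(60, 32))  # small random stand-in data

def run_tsne(seed: int) -> np.ndarray:
    # perplexity must be smaller than the number of samples (60 here)
    tsne = TSNE(n_components=2, perplexity=10, init="random", method="exact", random_state=seed)
    return tsne.fit_transform(fake_embeddings)

first = run_tsne(seed=42)
second = run_tsne(seed=42)
# Same data and same seed: the two layouts should match.
print(np.allclose(first, second))
```

With a different seed, the layout usually changes (rotations, flips, shifted cluster positions), even though the neighborhoods it tries to preserve are the same.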

# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1,536-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings)
# get the article labels for coloring the chart
labels = df["label"].tolist()

chart_from_components(
components=tsne_components,
labels=labels,
strings=article_descriptions,
width=600,
height=500,
title="t-SNE components of article descriptions",
)


[Chart: t-SNE components of article descriptions, colored by label]

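As the chart above shows, even the highly compressed embeddings do a good job of clustering article descriptions by category. It's worth emphasizing that this clustering is done without any knowledge of the labels themselves!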
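One way to put a number on "clusters by category without knowing the labels" is to measure how often an item's single nearest neighbor, by cosine similarity, shares its label. The helper below is a hypothetical sketch run on toy vectors. With the notebook's data you could try passing np.array(embeddings) and df["label"].tolist().

```python
import numpy as np

def nearest_neighbor_label_agreement(embeddings: np.ndarray, labels: list[str]) -> float:
    """Fraction of items whose nearest neighbor (by cosine similarity) has the same label."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = normed @ normed.T
    np.fill_diagonal(similarity, -np.inf)  # never count an item as its own neighbor
    nearest = similarity.argmax(axis=1)
    matches = [labels[i] == labels[j] for i, j in enumerate(nearest)]
    return float(np.mean(matches))

toy_embeddings = np.array([
    [1.0, 0.1], [0.9, 0.2],  # two vectors pointing one way ("Sports")
    [0.1, 1.0], [0.2, 0.9],  # two vectors pointing another way ("World")
])
toy_labels = ["Sports", "Sports", "World", "World"]
print(nearest_neighbor_label_agreement(toy_embeddings, toy_labels))  # 1.0
```

Because the labels are only used to score the result, not to compute the neighbors, a high value supports the observation that the embeddings group articles by category on their own.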

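Also, if you look closely at the most egregious outliers, they are often due to mislabeling rather than poor embedding. For example, most of the blue World points in the green Sports cluster appear to be Sports stories.

Next, let's recolor the points by whether they are a source article, its nearest neighbors, or other.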

# create labels for the recommended articles
def nearest_neighbor_labels(
    list_of_indices: list[int],
    k_nearest_neighbors: int = 5
) -> list[str]:
    """Return a list of labels to color the k nearest neighbors."""
    labels = ["Other" for _ in list_of_indices]
    # the first index is the source itself (distance 0 to its own embedding)
    source_index = list_of_indices[0]
    labels[source_index] = "Source"
    for i in range(k_nearest_neighbors):
        nearest_neighbor_index = list_of_indices[i + 1]
        labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})"
    return labels


tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5)


# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
components=tsne_components,
labels=tony_blair_labels,
strings=article_descriptions,
width=600,
height=500,
title="Nearest neighbors of the Tony Blair article",
category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)


[Chart: Nearest neighbors of the Tony Blair article]

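The 2D chart above shows that the articles about Tony Blair sit fairly close together inside the World news cluster. Interestingly, although the 5 nearest neighbors (red) were closest in the high-dimensional space, they are not the closest points in this compressed 2D space. Compressing the embeddings down to 2 dimensions discards much of their information, and the nearest neighbors in the 2D space don't seem to be as relevant as those in the full embedding space.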
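To quantify how much neighbor structure a 2D projection preserves, you can compare a point's k nearest neighbors in the full space with its k nearest neighbors after projection. The sketch below is hypothetical: it uses random unit vectors and a 2D PCA projection. A t-SNE projection could be passed in as `reduced` in the same way.

```python
import numpy as np

def knn_indices(points: np.ndarray, query_index: int, k: int) -> set[int]:
    """Indices of the k points nearest to points[query_index] (Euclidean), excluding itself."""
    distances = np.linalg.norm(points - points[query_index], axis=1)
    order = [int(i) for i in np.argsort(distances) if i != query_index]
    return set(order[:k])

def neighbor_overlap(full: np.ndarray, reduced: np.ndarray, query_index: int, k: int = 5) -> float:
    """Fraction of the query's k nearest neighbors in `full` that remain in its top k in `reduced`."""
    return len(knn_indices(full, query_index, k) & knn_indices(reduced, query_index, k)) / k

rng = np.random.default_rng(0)
full = rng.normal(size=(300, 256))
full /= np.linalg.norm(full, axis=1, keepdims=True)  # unit vectors: Euclidean order matches cosine order
centered = full - full.mean(axis=0)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
reduced = centered @ vt[:2].T  # 2D PCA projection

print(neighbor_overlap(full, reduced, query_index=0, k=5))
print(neighbor_overlap(full, full, query_index=0, k=5))  # no compression, so 1.0
```

A value well below 1.0 for the projected data is the numerical counterpart of the chart: the neighbors you see in 2D are not necessarily the neighbors the recommender actually used.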

# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
components=tsne_components,
labels=chipset_security_labels,
strings=article_descriptions,
width=600,
height=500,
title="Nearest neighbors of the chipset security article",
category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)


[Chart: Nearest neighbors of the chipset security article]

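For the chipset security example, the 4 closest neighbors in the full embedding space remain nearest neighbors in this compressed 2D visualization. The fifth appears farther away, despite being closer in the full embedding space.

If you like, you can also make an interactive 3D plot of the embeddings with the function chart_from_components_3D. (Doing so requires recomputing the t-SNE components with n_components=3.)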