"""模块包含一个支持缓存的嵌入器的代码。
支持缓存的嵌入器是一个包装在一个嵌入器周围的功能,它将嵌入缓存在一个键值存储中。缓存用于避免为相同文本重新计算嵌入。
文本被哈希化,哈希被用作缓存中的键。
"""
from __future__ import annotations
import hashlib
import json
import uuid
from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.utils.iter import batch_iterate
from langchain.storage.encoder_backed import EncoderBackedStore
NAMESPACE_UUID = uuid.UUID(int=1985)
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
"""对字符串进行哈希处理,并返回相应的UUID。"""
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _key_encoder(key: str, namespace: str) -> str:
"""编码一个密钥。"""
return namespace + str(_hash_string_to_uuid(key))
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
"""为密钥创建一个编码器。"""
return partial(_key_encoder, namespace=namespace)
def _value_serializer(value: Sequence[float]) -> bytes:
"""序列化一个数值。"""
return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> List[float]:
"""反序列化一个值。"""
return cast(List[float], json.loads(serialized_value.decode()))
[docs]class CacheBackedEmbeddings(Embeddings):
"""用于缓存嵌入模型结果的接口。
该接口允许与实现接受类型为str的键和浮点数列表值的抽象存储接口一起使用。
如果需要,接口可以扩展以接受其他值序列化器和反序列化器的实现,以及键编码器。
请注意,默认情况下仅缓存文档嵌入。要缓存查询嵌入,也可以将query_embedding_store传递给构造函数。
示例:
```python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import OpenAIEmbeddings
store = LocalFileStore('./my_cache')
underlying_embedder = OpenAIEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
underlying_embedder, store, namespace=underlying_embedder.model
)
# 嵌入被计算并缓存
embeddings = embedder.embed_documents(["hello", "goodbye"])
# 从缓存中检索嵌入,不进行计算
embeddings = embedder.embed_documents(["hello", "goodbye"])
```"""
[docs] def __init__(
self,
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, List[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
) -> None:
"""初始化嵌入器。
参数:
underlying_embeddings: 用于计算嵌入的嵌入器。
document_embedding_store: 用于缓存文档嵌入的存储。
batch_size: 在存储更新之间嵌入的文档数量。
query_embedding_store: 用于缓存查询嵌入的存储。
如果为None,则不缓存查询嵌入。
"""
super().__init__()
self.document_embedding_store = document_embedding_store
self.query_embedding_store = query_embedding_store
self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""嵌入文本列表。
该方法首先检查嵌入缓存。
如果找不到嵌入,该方法将使用底层的嵌入器来嵌入文档,并将结果存储在缓存中。
参数:
texts:要嵌入的文本列表。
返回:
给定文本的嵌入列表。
"""
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
texts
)
all_missing_indices: List[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
missing_texts = [texts[i] for i in missing_indices]
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
self.document_embedding_store.mset(
list(zip(missing_texts, missing_vectors))
)
for index, updated_vector in zip(missing_indices, missing_vectors):
vectors[index] = updated_vector
return cast(
List[List[float]], vectors
) # Nones should have been resolved by now
[docs] async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""嵌入文本列表。
该方法首先检查嵌入缓存。
如果找不到嵌入,该方法将使用底层的嵌入器来嵌入文档,并将结果存储在缓存中。
参数:
texts:要嵌入的文本列表。
返回:
给定文本的嵌入列表。
"""
vectors: List[
Union[List[float], None]
] = await self.document_embedding_store.amget(texts)
all_missing_indices: List[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
# batch_iterate supports None batch_size which returns all elements at once
# as a single batch.
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
missing_texts = [texts[i] for i in missing_indices]
missing_vectors = await self.underlying_embeddings.aembed_documents(
missing_texts
)
await self.document_embedding_store.amset(
list(zip(missing_texts, missing_vectors))
)
for index, updated_vector in zip(missing_indices, missing_vectors):
vectors[index] = updated_vector
return cast(
List[List[float]], vectors
) # Nones should have been resolved by now
[docs] def embed_query(self, text: str) -> List[float]:
"""嵌入查询文本。
默认情况下,此方法不会缓存查询。要启用缓存,请在初始化嵌入器时将`cache_query`参数设置为`True`。
参数:
text:要嵌入的文本。
返回:
给定文本的嵌入。
"""
if not self.query_embedding_store:
return self.underlying_embeddings.embed_query(text)
(cached,) = self.query_embedding_store.mget([text])
if cached is not None:
return cached
vector = self.underlying_embeddings.embed_query(text)
self.query_embedding_store.mset([(text, vector)])
return vector
[docs] async def aembed_query(self, text: str) -> List[float]:
"""嵌入查询文本。
默认情况下,此方法不会缓存查询。要启用缓存,请在初始化嵌入器时将`cache_query`参数设置为`True`。
参数:
text:要嵌入的文本。
返回:
给定文本的嵌入。
"""
if not self.query_embedding_store:
return await self.underlying_embeddings.aembed_query(text)
(cached,) = await self.query_embedding_store.amget([text])
if cached is not None:
return cached
vector = await self.underlying_embeddings.aembed_query(text)
await self.query_embedding_store.amset([(text, vector)])
return vector
[docs] @classmethod
def from_bytes_store(
cls,
underlying_embeddings: Embeddings,
document_embedding_cache: ByteStore,
*,
namespace: str = "",
batch_size: Optional[int] = None,
query_embedding_cache: Union[bool, ByteStore] = False,
) -> CacheBackedEmbeddings:
"""添加必要的序列化和编码到存储的入口。
参数:
underlying_embeddings: 用于嵌入的嵌入器。
document_embedding_cache: 用于存储文档嵌入的缓存。
*,
namespace: 用于文档缓存的命名空间。
此命名空间用于避免与其他缓存发生冲突。
例如,将其设置为所使用的嵌入模型的名称。
batch_size: 在存储更新之间要嵌入的文档数量。
query_embedding_cache: 用于存储查询嵌入的缓存。
True 表示使用与文档嵌入相同的缓存。
False 表示不缓存查询嵌入。
"""
namespace = namespace
key_encoder = _create_key_encoder(namespace)
document_embedding_store = EncoderBackedStore[str, List[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
if query_embedding_cache is True:
query_embedding_store = document_embedding_store
elif query_embedding_cache is False:
query_embedding_store = None
else:
query_embedding_store = EncoderBackedStore[str, List[float]](
query_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
return cls(
underlying_embeddings,
document_embedding_store,
batch_size=batch_size,
query_embedding_store=query_embedding_store,
)