from __future__ import annotations
import base64
import logging
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
import chromadb
import chromadb.config
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
# Module-scoped logger: library code should not bind the root logger, so that
# applications can configure or silence this module independently.
logger = logging.getLogger(__name__)

# Default number of Documents returned by the search methods below.
DEFAULT_K = 4
def _results_to_docs(results: Any) -> List[Document]:
    """Convert a Chroma query result into a list of Documents.

    Delegates to ``_results_to_docs_and_scores`` and drops the score that
    accompanies each document.

    Args:
        results: Query result mapping as accepted by
            ``_results_to_docs_and_scores``.

    Returns:
        The documents, in the order produced by
        ``_results_to_docs_and_scores``.
    """
    documents: List[Document] = []
    for document, _score in _results_to_docs_and_scores(results):
        documents.append(document)
    return documents
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
return [
# TODO: Chroma can do batch querying,
# we shouldn't hard code to the 1st result
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
class Chroma(VectorStore):
    """`ChromaDB` vector store.

    To use, you should have the ``chromadb`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Chroma
            from langchain_community.embeddings.openai import OpenAIEmbeddings

            embeddings = OpenAIEmbeddings()
            vectorstore = Chroma("langchain_store", embeddings)
    """

    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromadb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with a Chroma client.

        Args:
            collection_name: Name of the collection to get or create.
            embedding_function: Embeddings used to embed texts and queries.
                When ``None``, raw texts are handed to Chroma as-is.
            persist_directory: Directory to persist the collection in.
            client_settings: Chroma client settings. NOTE: when no ``client``
                is given, this object is mutated in place
                (``persist_directory`` and, on chromadb < 0.4.0,
                ``chroma_db_impl`` are set on it).
            collection_metadata: Metadata passed to
                ``get_or_create_collection``.
            client: An existing Chroma client. When given, it is used as-is
                and no new client is built from ``client_settings``.
            relevance_score_fn: Optional override for the function returned by
                ``_select_relevance_score_fn``.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        try:
            import chromadb
            import chromadb.config
        except ImportError as exc:
            raise ImportError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from exc

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    if self._is_legacy_chromadb(chromadb.__version__):
                        client_settings.chroma_db_impl = "duckdb+parquet"
                _client_settings = client_settings
            elif persist_directory:
                # Maintain backwards compatibility with chromadb < 0.4.0
                if self._is_legacy_chromadb(chromadb.__version__):
                    _client_settings = chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromadb.config.Settings(is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromadb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromadb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        # embedding_function=None: embeddings are computed by this wrapper
        # (when an embedding function was supplied) and passed in explicitly.
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @staticmethod
    def _is_legacy_chromadb(version: str) -> bool:
        """Return True when ``version`` denotes a chromadb release below 0.4.

        Only the first two dot-separated components are inspected, so version
        strings with extra components (e.g. ``"0.4.0.dev1"``) or with only a
        major component are accepted.
        """
        parts = version.split(".")
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        return major == 0 and minor < 4

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """The embedding function supplied at construction time, if any."""
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = DEFAULT_K,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the Chroma collection.

        Args:
            query_texts: Texts to query with.
            query_embeddings: Embeddings to query with.
            n_results: Number of results to request.
            where: Metadata filter.
            where_document: Document-content filter.
            **kwargs: Extra keyword arguments forwarded to
                ``collection.query``.

        Returns:
            The raw result of ``collection.query`` (indexed by string keys
            such as ``"documents"`` elsewhere in this class).

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            ) from exc
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def _upsert_documents(
        self,
        documents: List[str],
        embeddings: Optional[List[Any]],
        metadatas: Optional[List[dict]],
        ids: List[str],
        metadata_error_hint: str,
    ) -> None:
        """Upsert documents, sending entries with and without metadata separately.

        Args:
            documents: Document payloads (texts or base64-encoded images).
            embeddings: One embedding per document, or ``None``.
            metadatas: Optional metadata dicts; padded with empty dicts when
                shorter than ``documents``.
            ids: One ID per document.
            metadata_error_hint: Text appended to Chroma's error message when
                a metadata value is rejected.

        Raises:
            ValueError: Re-raised from ``collection.upsert``; when the message
                contains "Expected metadata value to be", the hint is appended.
        """
        if not metadatas:
            self._collection.upsert(
                embeddings=embeddings,
                documents=documents,
                ids=ids,
            )
            return

        # fill metadatas with empty dicts if somebody
        # did not specify metadata for all documents
        missing = len(documents) - len(metadatas)
        if missing > 0:
            metadatas = metadatas + [{}] * missing

        # Entries with empty metadata are upserted in a separate call that
        # omits the ``metadatas`` argument altogether.
        with_metadata = [idx for idx, meta in enumerate(metadatas) if meta]
        without_metadata = [idx for idx, meta in enumerate(metadatas) if not meta]

        if with_metadata:
            selected_metadatas = [metadatas[idx] for idx in with_metadata]
            selected_documents = [documents[idx] for idx in with_metadata]
            selected_embeddings = (
                [embeddings[idx] for idx in with_metadata] if embeddings else None
            )
            selected_ids = [ids[idx] for idx in with_metadata]
            try:
                self._collection.upsert(
                    metadatas=selected_metadatas,
                    embeddings=selected_embeddings,
                    documents=selected_documents,
                    ids=selected_ids,
                )
            except ValueError as e:
                if "Expected metadata value to be" in str(e):
                    raise ValueError(
                        e.args[0] + "\n\n" + metadata_error_hint
                    ) from e
                raise

        if without_metadata:
            self._collection.upsert(
                embeddings=(
                    [embeddings[idx] for idx in without_metadata]
                    if embeddings
                    else None
                ),
                documents=[documents[idx] for idx in without_metadata],
                ids=[ids[idx] for idx in without_metadata],
            )

    def encode_image(self, uri: str) -> str:
        """Read the file at ``uri`` and return its content as a base64 string."""
        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def add_images(
        self,
        uris: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more images through the embeddings and add to the vectorstore.

        Args:
            uris: File paths of the images.
            metadatas: Optional list of metadatas.
            ids: Optional list of IDs; random UUIDs are generated when omitted.

        Returns:
            List of IDs of the added images.
        """
        # Map from uris to b64 encoded strings
        b64_texts = [self.encode_image(uri=uri) for uri in uris]
        # Populate IDs
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in uris]
        embeddings = None
        # Set embeddings (only when the embedding function supports images)
        if self._embedding_function is not None and hasattr(
            self._embedding_function, "embed_image"
        ):
            embeddings = self._embedding_function.embed_image(uris=uris)
        self._upsert_documents(
            documents=b64_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
            metadata_error_hint=(
                "Try filtering complex metadata using "
                "langchain_community.vectorstores.utils.filter_complex_metadata."
            ),
        )
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Texts to add to the vectorstore.
            metadatas: Optional list of metadatas.
            ids: Optional list of IDs; random UUIDs are generated when omitted.

        Returns:
            List of IDs of the added texts.
        """
        # Materialize first: ``texts`` may be a one-shot iterator, and it is
        # needed both for generating IDs and for embedding/upserting.
        texts = list(texts)
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        embeddings = None
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        self._upsert_documents(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
            metadata_error_hint=(
                "Try filtering complex metadata from the document using "
                "langchain_community.vectorstores.utils.filter_complex_metadata."
            ),
        )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query: Query text to search for.
            k: Number of results to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            **kwargs: Forwarded to ``similarity_search_with_score``.

        Returns:
            List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to the embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.
            **kwargs: Forwarded to the collection query.

        Returns:
            List of documents most similar to the query vector.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
            **kwargs,
        )
        return _results_to_docs(results)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to the embedding vector, with their scores.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.
            **kwargs: Forwarded to the collection query.

        Returns:
            List of ``(document, distance)`` tuples, where the distance is the
            value reported by Chroma for that result. The original
            documentation describes it as a cosine distance for which a lower
            score means more similar; the metric actually used presumably
            depends on the collection configuration.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
            **kwargs,
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma and return distances.

        When no embedding function was supplied, the raw query text is passed
        to Chroma; otherwise the query is embedded first.

        Args:
            query: Query text to search for.
            k: Number of results to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.
            **kwargs: Forwarded to the collection query.

        Returns:
            List of ``(document, distance)`` tuples, where the distance is the
            value reported by Chroma for that result. The original
            documentation describes it as a cosine distance for which a lower
            score means more similar; the metric actually used presumably
            depends on the collection configuration.
        """
        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        return _results_to_docs_and_scores(results)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Select the function that maps a raw distance to a relevance score.

        The 'correct' relevance function may differ depending on a few things,
        including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others
          are not!)
        - embedding dimensionality
        - etc.

        The override passed to the constructor wins; otherwise the choice is
        driven by the ``"hnsw:space"`` key of the collection metadata, falling
        back to ``"l2"`` when it is absent.

        Raises:
            ValueError: If the configured distance metric is not one of
                ``"cosine"``, ``"l2"`` or ``"ip"``.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn

        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._collection.metadata

        if metadata and distance_key in metadata:
            distance = metadata[distance_key]

        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {distance}. "
                "Consider providing relevance_score_fn to Chroma constructor."
            )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            fetch_k: Number of documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.
            **kwargs: Forwarded to the collection query.

        Returns:
            List of documents selected by maximal marginal relevance, in the
            order in which the query returned them.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "distances", "embeddings"],
            **kwargs,
        )
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["embeddings"][0],
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        # Keep the candidates' own ordering; mmr_selected only decides
        # membership here.
        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            fetch_k: Number of documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.
            **kwargs: Accepted for interface compatibility; not forwarded.

        Returns:
            List of documents selected by maximal marginal relevance.

        Raises:
            ValueError: If no embedding function was supplied at creation.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on "
                "creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            where_document=where_document,
        )
        return docs

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get entries from the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the
                documents. E.g. `{$contains: "hello"}`. Optional.
            include: A list of what to include in the results.
                Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                Ids are always included.
                Defaults to `["metadatas", "documents"]`. Optional.

        Returns:
            The result of ``collection.get``.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        # Only pass ``include`` when set so that Chroma's own default applies.
        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)

    @deprecated(
        since="0.1.17",
        message=(
            "Since Chroma 0.4.x the manual persistence method is no longer "
            "supported as docs are automatically persisted."
        ),
        removal="0.3.0",
    )
    def persist(self) -> None:
        """Persist the collection.

        This can be used to explicitly persist the data to disk. The client's
        ``persist()`` is only invoked on chromadb < 0.4.0; since Chroma 0.4.x
        the manual persistence method is no longer supported as docs are
        automatically persisted, so this is a no-op there.

        Raises:
            ValueError: If no persist directory was configured at creation.
        """
        if self._persist_directory is None:
            raise ValueError(
                "You must specify a persist_directory on "
                "creation to persist the collection."
            )
        import chromadb

        # Maintain backwards compatibility with chromadb < 0.4.0
        if self._is_legacy_chromadb(chromadb.__version__):
            self._client.persist()

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id: ID of the document to update.
            document: Document to update.
        """
        return self.update_documents([document_id], [document])

    def update_documents(self, ids: List[str], documents: List[Document]) -> None:
        """Update documents in the collection.

        Args:
            ids: List of ids of the documents to update.
            documents: List of documents to update.

        Raises:
            ValueError: If no embedding function was supplied at creation.
        """
        text = [document.page_content for document in documents]
        metadata = [document.metadata for document in documents]
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents(text)

        if hasattr(
            self._collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            # Each batch is indexed as (ids, embeddings, metadatas, documents).
            for batch in create_batches(
                api=self._collection._client,
                ids=ids,
                metadatas=metadata,
                documents=text,
                embeddings=embeddings,
            ):
                self._collection.update(
                    ids=batch[0],
                    embeddings=batch[1],
                    documents=batch[3],
                    metadatas=batch[2],
                )
        else:
            self._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=text,
                metadatas=metadata,
            )

    @classmethod
    def from_texts(
        cls: Type[Chroma],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from raw texts.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            texts: List of texts to add to the collection.
            embedding: Embedding function. Defaults to None.
            metadatas: List of metadatas. Defaults to None.
            ids: List of document IDs. Defaults to None.
            collection_name: Name of the collection to create.
            persist_directory: Directory to persist the collection.
            client_settings: Chroma client settings.
            client: Existing Chroma client to use. Defaults to None.
            collection_metadata: Collection configurations. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        if hasattr(
            chroma_collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            # Each batch is indexed as (ids, embeddings, metadatas, documents).
            for batch in create_batches(
                api=chroma_collection._client,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            ):
                chroma_collection.add_texts(
                    texts=batch[3] if batch[3] else [],
                    metadatas=batch[2] if batch[2] else None,
                    ids=batch[0],
                )
        else:
            chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[Chroma],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            documents: List of documents to add to the vectorstore.
            embedding: Embedding function. Defaults to None.
            ids: List of document IDs. Defaults to None.
            collection_name: Name of the collection to create.
            persist_directory: Directory to persist the collection.
            client_settings: Chroma client settings.
            client: Existing Chroma client to use. Defaults to None.
            collection_metadata: Collection configurations. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to ``from_texts``.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            **kwargs: Extra keyword arguments forwarded to
                ``collection.delete`` (e.g. a ``where`` filter, if the
                installed chromadb supports it).
        """
        self._collection.delete(ids=ids, **kwargs)

    def __len__(self) -> int:
        """Count the number of documents in the collection."""
        return self._collection.count()