# Source code for langchain_community.vectorstores.chroma

from __future__ import annotations

import base64
import logging
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    import chromadb
    import chromadb.config
    from chromadb.api.types import ID, OneOrMany, Where, WhereDocument

# Use a module-scoped logger rather than the root logger so that applications
# can configure or silence this module's output independently.
logger = logging.getLogger(__name__)
DEFAULT_K = 4  # Number of Documents to return.


def _results_to_docs(results: Any) -> List[Document]:
    """Convert a raw Chroma query result into documents, dropping the scores."""
    scored_docs = _results_to_docs_and_scores(results)
    return [scored[0] for scored in scored_docs]


def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    return [
        # TODO: Chroma can do batch querying,
        # we shouldn't hard code to the 1st result
        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


class Chroma(VectorStore):
    """`ChromaDB` vector store.

    To use, you should have the ``chromadb`` python package installed.

    Example:
        .. code-block:: python

                from langchain_community.vectorstores import Chroma
                from langchain_community.embeddings.openai import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                vectorstore = Chroma("langchain_store", embeddings)
    """

    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

    def __init__(
        self,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromadb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with a Chroma client.

        Args:
            collection_name: Name of the collection to get or create.
            embedding_function: Embeddings used for documents and queries.
            persist_directory: Directory to persist the collection in.
            client_settings: Chroma client settings. Used to build a client
                only when ``client`` is not given.
            collection_metadata: Metadata passed to the collection on creation.
            client: An existing Chroma client. When given, it is used as is.
            relevance_score_fn: Overrides the function that maps a distance to
                a relevance score.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        try:
            import chromadb
            import chromadb.config
        except ImportError:
            raise ImportError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            )

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            legacy = self._is_legacy_chromadb(chromadb.__version__)
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None and legacy:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    client_settings.chroma_db_impl = "duckdb+parquet"
                _client_settings = client_settings
            elif persist_directory:
                if legacy:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    _client_settings = chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromadb.config.Settings(is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromadb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromadb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        # Embeddings are computed by this class, so the collection itself is
        # created without an embedding function.
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @staticmethod
    def _is_legacy_chromadb(version: str) -> bool:
        """Return True when ``version`` denotes a chromadb release below 0.4.0.

        Only the first two dot-separated components are inspected, so version
        strings with more (or fewer) than three components are accepted.
        """
        parts = version.split(".")
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        return major == 0 and minor < 4

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """The embedding function given at construction time, if any."""
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the Chroma collection and return the raw query result.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import chromadb python package. "
                "Please install it with `pip install chromadb`."
            )
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def encode_image(self, uri: str) -> str:
        """Read the file at ``uri`` and return its content as a base64 string."""
        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _upsert_entries(
        self,
        documents: List[str],
        embeddings: Optional[List[List[float]]],
        metadatas: Optional[List[dict]],
        ids: List[str],
        complex_metadata_hint: str,
    ) -> None:
        """Upsert documents, sending entries with and without metadata separately.

        Entries whose metadata is empty are upserted without a ``metadatas``
        argument; the others are upserted together with their metadata.

        Args:
            documents: Document payloads to store.
            embeddings: Embeddings aligned with ``documents``, or None.
            metadatas: Metadata aligned with ``documents``; may be shorter, in
                which case it is padded with empty dicts.
            ids: Ids aligned with ``documents``.
            complex_metadata_hint: Advice appended to the error raised when
                Chroma rejects a metadata value.

        Raises:
            ValueError: If the collection rejects the upsert.
        """
        if not metadatas:
            self._collection.upsert(
                embeddings=embeddings,
                documents=documents,
                ids=ids,
            )
            return

        # fill metadatas with empty dicts if somebody
        # did not specify metadata for all entries
        length_diff = len(documents) - len(metadatas)
        if length_diff:
            metadatas = metadatas + [{}] * length_diff

        empty_ids = []
        non_empty_ids = []
        for idx, m in enumerate(metadatas):
            if m:
                non_empty_ids.append(idx)
            else:
                empty_ids.append(idx)

        if non_empty_ids:
            try:
                self._collection.upsert(
                    metadatas=[metadatas[idx] for idx in non_empty_ids],
                    embeddings=(
                        [embeddings[idx] for idx in non_empty_ids]
                        if embeddings
                        else None
                    ),
                    documents=[documents[idx] for idx in non_empty_ids],
                    ids=[ids[idx] for idx in non_empty_ids],
                )
            except ValueError as e:
                if "Expected metadata value to be" in str(e):
                    raise ValueError(
                        e.args[0] + "\n\n" + complex_metadata_hint
                    ) from e
                raise

        if empty_ids:
            self._collection.upsert(
                embeddings=(
                    [embeddings[j] for j in empty_ids] if embeddings else None
                ),
                documents=[documents[j] for j in empty_ids],
                ids=[ids[j] for j in empty_ids],
            )

    def add_images(
        self,
        uris: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more images through the embeddings and add to the vectorstore.

        Args:
            uris: File paths of the images.
            metadatas: Optional list of metadatas.
            ids: Optional list of ids. Random UUIDs are generated when omitted.

        Returns:
            The ids of the added images.
        """
        # Map from uris to b64 encoded strings
        b64_texts = [self.encode_image(uri=uri) for uri in uris]
        # Populate IDs
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in uris]
        embeddings = None
        # Only embedding functions exposing ``embed_image`` can embed images.
        if self._embedding_function is not None and hasattr(
            self._embedding_function, "embed_image"
        ):
            embeddings = self._embedding_function.embed_image(uris=uris)
        self._upsert_entries(
            documents=b64_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
            complex_metadata_hint=(
                "Try filtering complex metadata using "
                "langchain_community.vectorstores.utils.filter_complex_metadata."
            ),
        )
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Texts to add to the vectorstore.
            metadatas: Optional list of metadatas.
            ids: Optional list of ids. Random UUIDs are generated when omitted.

        Returns:
            The ids of the added texts.
        """
        # Materialize first: ``texts`` may be a one-shot iterator, and it is
        # traversed more than once below.
        texts = list(texts)
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        embeddings = None
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        self._upsert_entries(
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids,
            complex_metadata_hint=(
                "Try filtering complex metadata from the document using "
                "langchain_community.vectorstores.utils.filter_complex_metadata."
            ),
        )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query: Query text to search for.
            k: Number of results to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.

        Returns:
            The documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to an embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.

        Returns:
            The documents most similar to the query vector.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
            **kwargs,
        )
        return _results_to_docs(results)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to an embedding vector, with their distances.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.

        Returns:
            The documents most similar to the query vector, each paired with
            the distance reported by Chroma. A lower score means more similar.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=k,
            where=filter,
            where_document=where_document,
            **kwargs,
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma, returning distances.

        Args:
            query: Query text to search for.
            k: Number of results to return. Defaults to 4.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.

        Returns:
            The documents most similar to the query text, each paired with the
            distance reported by Chroma. A lower score means more similar.
        """
        if self._embedding_function is None:
            # Without an embedding function the raw text is sent to Chroma.
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        return _results_to_docs_and_scores(results)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Pick the function that converts a distance into a relevance score.

        The 'correct' relevance function may differ depending on a few things,
        including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others
          are not!)
        - embedding dimensionality
        - etc.

        An override given at construction time wins; otherwise the choice is
        driven by the collection's ``hnsw:space`` metadata, with ``l2`` assumed
        when it is absent.

        Raises:
            ValueError: If the collection's distance metric is not supported.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn

        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._collection.metadata

        if metadata and distance_key in metadata:
            distance = metadata[distance_key]

        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {distance}. "
                "Consider providing relevance_score_fn to Chroma constructor."
            )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Filter by metadata. Defaults to None.
            where_document: Filter by document content. Defaults to None.

        Returns:
            The documents selected by maximal marginal relevance, in the order
            in which Chroma returned the candidates.
        """
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "distances", "embeddings"],
            **kwargs,
        )
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["embeddings"][0],
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get entries of the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                   E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                    Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the
                            documents. E.g. `{"$contains": "hello"}`. Optional.
            include: A list of what to include in the results.
                     Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                     Ids are always included.
                     Defaults to `["metadatas", "documents"]`. Optional.

        Returns:
            The result of the underlying collection's ``get`` call.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        # Only forward ``include`` when set so the collection default applies.
        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)

    @deprecated(
        since="0.1.17",
        message=(
            "Since Chroma 0.4.x the manual persistence method is no longer "
            "supported as docs are automatically persisted."
        ),
        removal="0.3.0",
    )
    def persist(self) -> None:
        """Persist the collection.

        This can be used to explicitly persist the data to disk. It only has
        an effect with chromadb < 0.4.0; since Chroma 0.4.x the manual
        persistence method is no longer supported as docs are automatically
        persisted.

        Raises:
            ValueError: If no persist directory was configured.
        """
        if self._persist_directory is None:
            raise ValueError(
                "You must specify a persist_directory on "
                "creation to persist the collection."
            )
        import chromadb

        # Maintain backwards compatibility with chromadb < 0.4.0
        if self._is_legacy_chromadb(chromadb.__version__):
            self._client.persist()

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id: ID of the document to update.
            document: Document to update.
        """
        return self.update_documents([document_id], [document])

    def update_documents(self, ids: List[str], documents: List[Document]) -> None:
        """Update documents in the collection.

        Args:
            ids: List of ids of the documents to update.
            documents: List of documents to update.

        Raises:
            ValueError: If no embedding function was given at construction.
        """
        text = [document.page_content for document in documents]
        metadata = [document.metadata for document in documents]
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents(text)

        if hasattr(
            self._collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            for batch in create_batches(
                api=self._collection._client,
                ids=ids,
                metadatas=metadata,
                documents=text,
                embeddings=embeddings,
            ):
                # Each batch is indexed as (ids, embeddings, metadatas, documents).
                self._collection.update(
                    ids=batch[0],
                    embeddings=batch[1],
                    documents=batch[3],
                    metadatas=batch[2],
                )
        else:
            self._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=text,
                metadatas=metadata,
            )

    @classmethod
    def from_texts(
        cls: Type[Chroma],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from raw texts.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            texts: List of texts to add to the collection.
            collection_name: Name of the collection to create.
            persist_directory: Directory to persist the collection.
            embedding: Embedding function. Defaults to None.
            metadatas: List of metadatas. Defaults to None.
            ids: List of document IDs. Defaults to None.
            client_settings: Chroma client settings.
            client: An existing Chroma client. Defaults to None.
            collection_metadata: Collection configurations. Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        if hasattr(
            chroma_collection._client, "max_batch_size"
        ):  # for Chroma 0.4.10 and above
            from chromadb.utils.batch_utils import create_batches

            for batch in create_batches(
                api=chroma_collection._client,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            ):
                # Each batch is indexed as (ids, embeddings, metadatas, documents).
                chroma_collection.add_texts(
                    texts=batch[3] if batch[3] else [],
                    metadatas=batch[2] if batch[2] else None,
                    ids=batch[0],
                )
        else:
            chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[Chroma],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted
        there. Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name: Name of the collection to create.
            persist_directory: Directory to persist the collection.
            ids: List of document IDs. Defaults to None.
            documents: List of documents to add to the vectorstore.
            embedding: Embedding function. Defaults to None.
            client_settings: Chroma client settings.
            client: An existing Chroma client. Defaults to None.
            collection_metadata: Collection configurations. Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
            **kwargs: Extra keyword arguments forwarded to the collection's
                ``delete`` call.
        """
        self._collection.delete(ids=ids, **kwargs)

    def __len__(self) -> int:
        """Count the number of documents in the collection."""
        return self._collection.count()