Source code for langchain_community.vectorstores.milvus

from __future__ import annotations

import logging
from typing import Any, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)

DEFAULT_MILVUS_CONNECTION = {
    "host": "localhost",
    "port": "19530",
    "user": "",
    "password": "",
    "secure": False,
}


class Milvus(VectorStore):
    """`Milvus` vector store.

    You need to install `pymilvus` and have a running Milvus instance.

    See the following documentation for how to run a Milvus instance:
    https://milvus.io/docs/install_standalone-docker.md

    If looking for a hosted Milvus, take a look at this documentation:
    https://zilliz.com/cloud and make use of the Zilliz vector store found in
    this project.

    If using the L2/IP metric, it is highly recommended to normalize your data.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Milvus collection to use. Defaults to
            "LangChainCollection".
        collection_description (str): The description of the collection. Defaults to "".
        collection_properties (Optional[dict[str, Any]]): The collection properties.
            Defaults to None. If set, will override the collection's existing
            properties. For example: {"collection.ttl.seconds": 60}.
        connection_args (Optional[dict[str, Any]]): The connection args used for
            this class, in the form of a dict.
        consistency_level (str): The consistency level to use for a collection.
            Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. Defaults to
            HNSW/AUTOINDEX depending on the service.
        search_params (Optional[dict]): Which search params to use. Defaults to
            the default of the index.
        drop_old (Optional[bool]): Whether to drop the current collection.
            Defaults to False.
        auto_id (bool): Whether to enable auto id for the primary key. Defaults to
            False. If False, you need to provide text ids (strings smaller than
            65535 bytes). If True, Milvus will generate unique integers as
            primary keys.
        primary_field (str): Name of the primary key field. Defaults to "pk".
        text_field (str): Name of the text field. Defaults to "text".
        vector_field (str): Name of the vector field. Defaults to "vector".
        metadata_field (str): Name of the metadata field. Defaults to None.
            When metadata_field is specified, the document's metadata will be
            stored as json.

    The connection args used for this class come in the form of a dict;
    here are a few of the options:
        address (str): The actual address of the Milvus instance.
            Example address: "localhost:19530".
        uri (str): The uri of the Milvus instance. Example uri:
            "http://randomwebsite:19530",
            "tcp:foobarsite:19530",
            "https://ok.s3.south.com:19530".
        host (str): The host of the Milvus instance. Defaults to "localhost";
            PyMilvus will fill in the default host if only the port is provided.
        port (str/int): The port of the Milvus instance. Defaults to 19530;
            PyMilvus will fill in the default port if only the host is provided.
        user (str): Which user to use when connecting to the Milvus instance. If
            user and password are provided, the related header is added to every
            RPC call.
        password (str): Required when a user is provided. The password
            corresponding to the user.
        secure (bool): Defaults to False. If set to True, tls will be enabled.
        client_key_path (str): If using tls two-way authentication, the path to
            the client.key file.
        client_pem_path (str): If using tls two-way authentication, the path to
            the client.pem file.
        ca_pem_path (str): If using tls two-way authentication, the path to the
            ca.pem file.
        server_pem_path (str): If using tls one-way authentication, the path to
            the server.pem file.
        server_name (str): If using tls, the common name to use.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Milvus
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a Milvus instance on localhost
            milvus_store = Milvus(
                embedding_function=embedding,
                collection_name="LangChainCollection",
                drop_old=True,
                auto_id=True,
            )

    Raises:
        ValueError: If the pymilvus python package is not installed.
    """
    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        collection_description: str = "",
        collection_properties: Optional[dict[str, Any]] = None,
        connection_args: Optional[dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
        auto_id: bool = False,
        *,
        primary_field: str = "pk",
        text_field: str = "text",
        vector_field: str = "vector",
        metadata_field: Optional[str] = None,
        partition_key_field: Optional[str] = None,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
        num_shards: Optional[int] = None,
    ):
        """Initialize the Milvus vector store."""
        try:
            from pymilvus import Collection, utility
        except ImportError:
            raise ImportError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )

        # Default search params when one is not provided.
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
            "SCANN": {"metric_type": "L2", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "L2", "params": {}},
            "GPU_CAGRA": {
                "metric_type": "L2",
                "params": {
                    "itopk_size": 128,
                    "search_width": 4,
                    "min_iterations": 0,
                    "max_iterations": 0,
                    "team_size": 0,
                },
            },
            "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
            "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
        }

        self.embedding_func = embedding_function
        self.collection_name = collection_name
        self.collection_description = collection_description
        self.collection_properties = collection_properties
        self.index_params = index_params
        self.search_params = search_params
        self.consistency_level = consistency_level
        self.auto_id = auto_id

        # In order for a collection to be compatible, pk needs to be varchar
        self._primary_field = primary_field
        # In order for compatibility, the text field will need to be called "text"
        self._text_field = text_field
        # In order for compatibility, the vector field needs to be called "vector"
        self._vector_field = vector_field
        self._metadata_field = metadata_field
        self._partition_key_field = partition_key_field
        self.fields: list[str] = []
        self.partition_names = partition_names
        self.replica_number = replica_number
        self.timeout = timeout
        self.num_shards = num_shards

        # Create the connection to the server
        if connection_args is None:
            connection_args = DEFAULT_MILVUS_CONNECTION
        self.alias = self._create_connection_alias(connection_args)
        self.col: Optional[Collection] = None

        # Grab the existing collection if it exists
        if utility.has_collection(self.collection_name, using=self.alias):
            self.col = Collection(
                self.collection_name,
                using=self.alias,
            )
            if self.collection_properties is not None:
                self.col.set_properties(self.collection_properties)
        # If need to drop old, drop it
        if drop_old and isinstance(self.col, Collection):
            self.col.drop()
            self.col = None

        # Initialize the vector store
        self._init(
            partition_names=partition_names,
            replica_number=replica_number,
            timeout=timeout,
        )
    @property
    def embeddings(self) -> Embeddings:
        return self.embedding_func

    def _create_connection_alias(self, connection_args: dict) -> str:
        """Create the connection to the Milvus server."""
        from pymilvus import MilvusException, connections

        # Grab the connection arguments used for checking an existing connection
        host: str = connection_args.get("host", None)
        port: Union[str, int] = connection_args.get("port", None)
        address: str = connection_args.get("address", None)
        uri: str = connection_args.get("uri", None)
        user = connection_args.get("user", None)

        # Order of use is host/port, uri, address
        if host is not None and port is not None:
            given_address = str(host) + ":" + str(port)
        elif uri is not None:
            if uri.startswith("https://"):
                given_address = uri.split("https://")[1]
            elif uri.startswith("http://"):
                given_address = uri.split("http://")[1]
            else:
                logger.error("Invalid Milvus URI: %s", uri)
                raise ValueError(f"Invalid Milvus URI: {uri}")
        elif address is not None:
            given_address = address
        else:
            given_address = None
            logger.debug("Missing standard address type for reuse attempt")

        # User defaults to empty string when getting connection info
        if user is not None:
            tmp_user = user
        else:
            tmp_user = ""

        # If a valid address was given, then check if a connection exists
        if given_address is not None:
            for con in connections.list_connections():
                addr = connections.get_connection_addr(con[0])
                if (
                    con[1]
                    and ("address" in addr)
                    and (addr["address"] == given_address)
                    and ("user" in addr)
                    and (addr["user"] == tmp_user)
                ):
                    logger.debug("Using previous connection: %s", con[0])
                    return con[0]

        # Generate a new connection if one doesn't exist
        alias = uuid4().hex
        try:
            connections.connect(alias=alias, **connection_args)
            logger.debug("Created new connection using: %s", alias)
            return alias
        except MilvusException as e:
            logger.error("Failed to create new connection using: %s", alias)
            raise e

    def _init(
        self,
        embeddings: Optional[list] = None,
        metadatas: Optional[list[dict]] = None,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
    ) -> None:
        if embeddings is not None:
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()
        self._create_search_params()
        self._load(
            partition_names=partition_names,
            replica_number=replica_number,
            timeout=timeout,
        )

    def _create_collection(
        self, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        from pymilvus import (
            Collection,
            CollectionSchema,
            DataType,
            FieldSchema,
            MilvusException,
        )
        from pymilvus.orm.types import infer_dtype_bydata

        # Determine embedding dim
        dim = len(embeddings[0])
        fields = []
        if self._metadata_field is not None:
            fields.append(FieldSchema(self._metadata_field, DataType.JSON))
        else:
            # Determine metadata schema
            if metadatas:
                # Create FieldSchema for each entry in metadata.
                for key, value in metadatas[0].items():
                    # Infer the corresponding datatype of the metadata
                    dtype = infer_dtype_bydata(value)
                    # Datatype isn't compatible
                    if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
                        logger.error(
                            (
                                "Failure to create collection, "
                                "unrecognized dtype for key: %s"
                            ),
                            key,
                        )
                        raise ValueError(f"Unrecognized datatype for {key}.")
                    # Datatype is a string/varchar equivalent
                    elif dtype == DataType.VARCHAR:
                        fields.append(
                            FieldSchema(key, DataType.VARCHAR, max_length=65_535)
                        )
                    else:
                        fields.append(FieldSchema(key, dtype))

        # Create the text field
        fields.append(
            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
        )
        # Create the primary key field
        if self.auto_id:
            fields.append(
                FieldSchema(
                    self._primary_field, DataType.INT64, is_primary=True, auto_id=True
                )
            )
        else:
            fields.append(
                FieldSchema(
                    self._primary_field,
                    DataType.VARCHAR,
                    is_primary=True,
                    auto_id=False,
                    max_length=65_535,
                )
            )
        # Create the vector field, supports binary or float vectors
        fields.append(
            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
        )

        # Create the schema for the collection
        schema = CollectionSchema(
            fields,
            description=self.collection_description,
            partition_key_field=self._partition_key_field,
        )

        # Create the collection
        try:
            if self.num_shards is not None:
                # Issue with defaults:
                # https://github.com/milvus-io/pymilvus/blob/59bf5e811ad56e20946559317fed855330758d9c/pymilvus/client/prepare.py#L82-L85
                self.col = Collection(
                    name=self.collection_name,
                    schema=schema,
                    consistency_level=self.consistency_level,
                    using=self.alias,
                    num_shards=self.num_shards,
                )
            else:
                self.col = Collection(
                    name=self.collection_name,
                    schema=schema,
                    consistency_level=self.consistency_level,
                    using=self.alias,
                )
            # Set the collection properties if they exist
            if self.collection_properties is not None:
                self.col.set_properties(self.collection_properties)
        except MilvusException as e:
            logger.error(
                "Failed to create collection: %s error: %s", self.collection_name, e
            )
            raise e

    def _extract_fields(self) -> None:
        """Grab the existing fields from the Collection."""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            schema = self.col.schema
            for x in schema.fields:
                self.fields.append(x.name)

    def _get_index(self) -> Optional[dict[str, Any]]:
        """Return the vector index information if it exists."""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            for x in self.col.indexes:
                if x.field_name == self._vector_field:
                    return x.to_dict()
        return None

    def _create_index(self) -> None:
        """Create an index on the collection."""
        from pymilvus import Collection, MilvusException

        if isinstance(self.col, Collection) and self._get_index() is None:
            try:
                # If no index params, use a default HNSW based one
                if self.index_params is None:
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }

                try:
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )

                # If the default did not work, most likely on Zilliz Cloud
                except MilvusException:
                    # Use an AUTOINDEX based index
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                logger.debug(
                    "Successfully created an index on collection: %s",
                    self.collection_name,
                )

            except MilvusException as e:
                logger.error(
                    "Failed to create an index on collection: %s", self.collection_name
                )
                raise e

    def _create_search_params(self) -> None:
        """Generate search params based on the current index type."""
        from pymilvus import Collection

        if isinstance(self.col, Collection) and self.search_params is None:
            index = self._get_index()
            if index is not None:
                index_type: str = index["index_param"]["index_type"]
                metric_type: str = index["index_param"]["metric_type"]
                self.search_params = self.default_search_params[index_type]
                self.search_params["metric_type"] = metric_type

    def _load(
        self,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
    ) -> None:
        """Load the collection if available."""
        from pymilvus import Collection, utility
        from pymilvus.client.types import LoadState

        timeout = self.timeout or timeout
        if (
            isinstance(self.col, Collection)
            and self._get_index() is not None
            and utility.load_state(self.collection_name, using=self.alias)
            == LoadState.NotLoad
        ):
            self.col.load(
                partition_names=partition_names,
                replica_number=replica_number,
                timeout=timeout,
            )
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[float] = None,
        batch_size: int = 1000,
        *,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Milvus.

        Inserting data when the collection has not been created yet will result
        in a new collection being created. The data of the first entity decides
        the schema of the new collection: the dim is extracted from the first
        embedding and the columns are decided by the first metadata dict.
        Metadata keys need to be present for all inserted values; at the moment
        there is no None equivalent in Milvus.

        Args:
            texts (Iterable[str]): The texts to embed; it is assumed that they
                all fit in memory.
            metadatas (Optional[List[dict]]): Metadata dicts attached to each of
                the texts. Defaults to None.
            timeout (Optional[float]): Timeout for each batch insert. Defaults
                to None.
            batch_size (int, optional): Batch size to use for insertion.
                Defaults to 1000.
            ids (Optional[List[str]]): List of text ids. Each id should be a
                string smaller than 65535 bytes. Required and only effective
                when auto_id is False.

        Raises:
            MilvusException: Failure to add texts.

        Returns:
            List[str]: The resulting keys for each inserted element.
        """
        from pymilvus import Collection, MilvusException

        texts = list(texts)
        if not self.auto_id:
            assert isinstance(
                ids, list
            ), "A list of valid ids are required when auto_id is False."
            assert len(set(ids)) == len(
                texts
            ), "Different lengths of texts and unique ids are provided."
            assert all(
                len(x.encode()) <= 65_535 for x in ids
            ), "Each id should be a string less than 65535 bytes."

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        # If the collection hasn't been initialized yet, perform all steps to do so
        if not isinstance(self.col, Collection):
            kwargs = {"embeddings": embeddings, "metadatas": metadatas}
            if self.partition_names:
                kwargs["partition_names"] = self.partition_names
            if self.replica_number:
                kwargs["replica_number"] = self.replica_number
            if self.timeout:
                kwargs["timeout"] = self.timeout
            self._init(**kwargs)

        # Dict to hold all insert columns
        insert_dict: dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }

        if not self.auto_id:
            insert_dict[self._primary_field] = ids  # type: ignore[assignment]

        if self._metadata_field is not None:
            for d in metadatas:  # type: ignore[union-attr]
                insert_dict.setdefault(self._metadata_field, []).append(d)
        else:
            # Collect the metadata into the insert dict.
            if metadatas is not None:
                for d in metadatas:
                    for key, value in d.items():
                        keys = (
                            [x for x in self.fields if x != self._primary_field]
                            if self.auto_id
                            else [x for x in self.fields]
                        )
                        if key in keys:
                            insert_dict.setdefault(key, []).append(value)

        # Total insert count
        vectors: list = insert_dict[self._vector_field]
        total_count = len(vectors)

        pks: list[str] = []

        assert isinstance(self.col, Collection)
        for i in range(0, total_count, batch_size):
            # Grab end index
            end = min(i + batch_size, total_count)
            # Convert dict to list of lists batch for insertion
            insert_list = [
                insert_dict[x][i:end] for x in self.fields if x in insert_dict
            ]
            # Insert into the collection.
            try:
                res: Collection
                timeout = self.timeout or timeout
                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
                pks.extend(res.primary_keys)
            except MilvusException as e:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise e
        return pks
    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the given embedding vector.

        Args:
            embedding (List[float]): The embedding vector to search with.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        timeout = self.timeout or timeout
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        For more information about the search parameters, take a look at the
        pymilvus documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            query (str): The text being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Tuple[Document, float]]: Result documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)
        timeout = self.timeout or timeout
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return res
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on an embedding vector and return results with score.

        For more information about the search parameters, take a look at the
        pymilvus documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Tuple[Document, float]]: Result documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields with PK.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)
        timeout = self.timeout or timeout
        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )
        # Organize results.
        ret = []
        for result in res[0]:
            data = {x: result.entity.get(x) for x in output_fields}
            doc = self._parse_document(data)
            pair = (doc, result.score)
            ret.append(pair)

        return ret
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results reordered by MMR.

        Args:
            embedding (list[float]): The embedding vector being searched.
            k (int, optional): How many results to return. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            param (dict, optional): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Keyword arguments for Collection.search().

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)
        timeout = self.timeout or timeout
        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=fetch_k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )
        # Organize results.
        ids = []
        documents = []
        scores = []
        for result in res[0]:
            data = {x: result.entity.get(x) for x in output_fields}
            doc = self._parse_document(data)
            documents.append(doc)
            scores.append(result.score)
            ids.append(result.id)

        vectors = self.col.query(
            expr=f"{self._primary_field} in {ids}",
            output_fields=[self._primary_field, self._vector_field],
            timeout=timeout,
        )
        # Reorganize the results from query to match search order.
        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}

        ordered_result_embeddings = [vectors[x] for x in ids]

        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )

        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            else:
                ret.append(documents[x])
        return ret
    def delete(  # type: ignore[no-untyped-def]
        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
    ):
        """Delete by vector ID or boolean expression.

        Refer to the [Milvus documentation](https://milvus.io/docs/delete_data.md)
        for notes and examples of expressions.

        Args:
            ids: List of ids to delete.
            expr: Boolean expression that specifies the entities to delete.
            kwargs: Other parameters of the Milvus delete API.
        """
        if isinstance(ids, list) and len(ids) > 0:
            if expr is not None:
                logger.warning(
                    "Both ids and expr are provided. Ignoring expr and deleting by ids."
                )
            expr = f"{self._primary_field} in {ids}"
        else:
            assert isinstance(
                expr, str
            ), "Either ids list or expr string must be provided."
        return self.col.delete(expr=expr, **kwargs)  # type: ignore[union-attr]
    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        *,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Milvus:
        """Create a Milvus collection, index it with HNSW, and insert data.

        Args:
            texts (List[str]): Text data.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            collection_name (str, optional): Collection name to use. Defaults to
                "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to use.
                Defaults to DEFAULT_MILVUS_CONNECTION.
            consistency_level (str, optional): Which consistency level to use.
                Defaults to "Session".
            index_params (Optional[dict], optional): Which index_params to use.
                Defaults to None.
            search_params (Optional[dict], optional): Which search params to use.
                Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the collection with
                that name if it exists. Defaults to False.
            ids (Optional[List[str]]): List of text ids. Defaults to None.

        Returns:
            Milvus: Milvus Vector Store
        """
        if isinstance(ids, list) and len(ids) > 0:
            auto_id = False
        else:
            auto_id = True

        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args,
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            auto_id=auto_id,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return vector_db
    def _parse_document(self, data: dict) -> Document:
        return Document(
            page_content=data.pop(self._text_field),
            metadata=data.pop(self._metadata_field) if self._metadata_field else data,
        )
    def get_pks(self, expr: str, **kwargs: Any) -> List[int] | None:
        """Get primary keys matching an expression.

        Args:
            expr: Expression, e.g. "id in [1, 2]" or "title LIKE 'Abc%'".

        Returns:
            List[int]: List of IDs (primary keys).
        """
        from pymilvus import MilvusException

        if self.col is None:
            logger.debug("No existing collection to get pk.")
            return None

        try:
            query_result = self.col.query(
                expr=expr, output_fields=[self._primary_field]
            )
        except MilvusException as exc:
            logger.error("Failed to get ids: %s error: %s", self.collection_name, exc)
            raise exc
        pks = [item.get(self._primary_field) for item in query_result]
        return pks
    def upsert(
        self,
        ids: Optional[List[str]] = None,
        documents: List[Document] | None = None,
        **kwargs: Any,
    ) -> List[str] | None:
        """Update/insert documents into the vector store.

        Args:
            ids: IDs to update. Use get_pks(expr) to look up the ids matching an
                expression.
            documents (List[Document]): Documents to add to the vector store.

        Returns:
            List[str]: IDs of the added texts.
        """
        from pymilvus import MilvusException

        if documents is None or len(documents) == 0:
            logger.debug("No documents to upsert.")
            return None

        if ids is not None and len(ids):
            try:
                self.delete(ids=ids)
            except MilvusException:
                pass
        try:
            return self.add_documents(documents=documents, **kwargs)
        except MilvusException as exc:
            logger.error(
                "Failed to upsert entities: %s error: %s", self.collection_name, exc
            )
            raise exc
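
# A minimal usage sketch (illustrative, not part of the module above). It assumes a
# Milvus instance is reachable on localhost:19530 and uses FakeEmbeddings from
# langchain_community purely as a stand-in embedding model; the collection name,
# ids, texts, and filter expressions below are made up for the example.
if __name__ == "__main__":
    from langchain_community.embeddings import FakeEmbeddings

    embeddings = FakeEmbeddings(size=128)

    # Build a collection from raw texts. Passing explicit string ids keeps
    # auto_id=False, so later deletes can address entities by id.
    store = Milvus.from_texts(
        texts=["alpha", "beta", "gamma"],
        embedding=embeddings,
        metadatas=[{"source": "demo"}, {"source": "demo"}, {"source": "demo"}],
        ids=["doc-1", "doc-2", "doc-3"],
        collection_name="LangChainExample",
        connection_args={"host": "localhost", "port": "19530"},
        drop_old=True,
    )

    # Scored similarity search on a query string.
    for doc, score in store.similarity_search_with_score("alpha", k=2):
        print(doc.page_content, score)

    # MMR reordering operates on a precomputed query vector.
    mmr_docs = store.max_marginal_relevance_search_by_vector(
        embedding=embeddings.embed_query("alpha"), k=2, fetch_k=3
    )
    print([d.page_content for d in mmr_docs])

    # Look up primary keys with a boolean expression, then delete those entities.
    pks = store.get_pks(expr='text == "gamma"')
    if pks:
        store.delete(ids=pks)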