from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple, Union
from uuid import uuid4
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
# Module-level logger used throughout the Milvus vector store.
logger = logging.getLogger(__name__)

# Connection settings used when the caller supplies none: a local Milvus
# server on port 19530, no credentials, TLS disabled ("secure": False).
DEFAULT_MILVUS_CONNECTION = dict(
    host="localhost",
    port="19530",
    user="",
    password="",
    secure=False,
)
class Milvus(VectorStore):
    """`Milvus` vector store.

    You need to install `pymilvus` and have a running Milvus instance.

    See the following documentation for how to run a Milvus instance:
    https://milvus.io/docs/install_standalone-docker.md

    If you are looking for a hosted Milvus, see https://zilliz.com/cloud and
    use the Zilliz vector store found in this project.

    If using L2/IP metrics, it is highly recommended to normalize your data.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Milvus collection to use.
            Defaults to "LangChainCollection".
        collection_description (str): The description of the collection.
            Defaults to "".
        collection_properties (Optional[dict[str, any]]): The collection
            properties. Defaults to None. If set, will override the existing
            properties of the collection, e.g. {"collection.ttl.seconds": 60}.
        connection_args (Optional[dict[str, any]]): The connection args used
            for this class, given as a dict (see below). Defaults to
            DEFAULT_MILVUS_CONNECTION.
        consistency_level (str): The consistency level to use for the
            collection. Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. Defaults to
            an HNSW index, falling back to AUTOINDEX if creating it fails.
        search_params (Optional[dict]): Which search params to use. Defaults
            to the defaults registered for the collection's index type.
        drop_old (Optional[bool]): Whether to drop the current collection.
            Defaults to False.
        auto_id (bool): Whether to enable auto id for the primary key.
            Defaults to False. If False, you need to provide text ids
            (strings of at most 65535 bytes). If True, Milvus will generate
            unique integers as primary keys.
        primary_field (str): Name of the primary key field. Defaults to "pk".
        text_field (str): Name of the text field. Defaults to "text".
        vector_field (str): Name of the vector field. Defaults to "vector".
        metadata_field (str): Name of the metadata field. Defaults to None.
            When set, each document's metadata is stored in this single
            JSON field.
        partition_key_field (Optional[str]): Field passed to the collection
            schema as ``partition_key_field``. Defaults to None.
        partition_names (Optional[list]): Partitions to load. Defaults to None.
        replica_number (int): Number of replicas to load. Defaults to 1.
        timeout (Optional[float]): Timeout applied to load/insert/search
            calls; takes precedence over per-call timeouts. Defaults to None.
        num_shards (Optional[int]): Number of shards for a newly created
            collection. Defaults to None (server default).

    The connection args used for this class come in the form of a dict;
    here are a few of the options:
        address (str): The actual address of the Milvus instance.
            Example address: "localhost:19530".
        uri (str): The uri of the Milvus instance. Must start with
            "http://" or "https://", otherwise a ValueError is raised.
            Example uris: "http://randomwebsite:19530",
            "https://ok.s3.south.com:19530".
        host (str): The host of the Milvus instance. Defaults to "localhost";
            PyMilvus will fill in the default host if only a port is provided.
        port (str/int): The port of the Milvus instance. Defaults to 19530;
            PyMilvus will fill in the default port if only a host is provided.
        user (str): The user used to connect to the Milvus instance. If user
            and password are provided, the relevant headers are added to
            every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.
        secure (bool): Defaults to false. If set to true, tls is enabled.
        client_key_path (str): If using tls two-way authentication, the path
            to client.key.
        client_pem_path (str): If using tls two-way authentication, the path
            to client.pem.
        ca_pem_path (str): If using tls two-way authentication, the path to
            ca.pem.
        server_pem_path (str): If using tls one-way authentication, the path
            to server.pem.
        server_name (str): If using tls, the common name.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Milvus
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a milvus instance on localhost
            milvus_store = Milvus(
                embedding_function=embedding,
                collection_name="LangChainCollection",
                drop_old=True,
                auto_id=True,
            )

    Raises:
        ImportError: If the pymilvus python package is not installed.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        collection_name: str = "LangChainCollection",
        collection_description: str = "",
        collection_properties: Optional[dict[str, Any]] = None,
        connection_args: Optional[dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
        auto_id: bool = False,
        *,
        primary_field: str = "pk",
        text_field: str = "text",
        vector_field: str = "vector",
        metadata_field: Optional[str] = None,
        partition_key_field: Optional[str] = None,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
        num_shards: Optional[int] = None,
    ):
        """Initialize the Milvus vector store.

        Connects (or reuses a connection) to the server, binds to an existing
        collection if one with ``collection_name`` exists (dropping it when
        ``drop_old`` is true), then creates index/search params and loads it.
        """
        try:
            from pymilvus import Collection, utility
        except ImportError:
            raise ImportError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )

        # Default search params when one is not provided.
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
            "SCANN": {"metric_type": "L2", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "L2", "params": {}},
            "GPU_CAGRA": {
                "metric_type": "L2",
                "params": {
                    "itopk_size": 128,
                    "search_width": 4,
                    "min_iterations": 0,
                    "max_iterations": 0,
                    "team_size": 0,
                },
            },
            "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
            "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
        }

        self.embedding_func = embedding_function
        self.collection_name = collection_name
        self.collection_description = collection_description
        self.collection_properties = collection_properties
        self.index_params = index_params
        self.search_params = search_params
        self.consistency_level = consistency_level
        self.auto_id = auto_id

        # In order for a collection to be compatible, pk needs to be varchar
        self._primary_field = primary_field
        # In order for compatibility, the text field will need to be called "text"
        self._text_field = text_field
        # In order for compatibility, the vector field needs to be called "vector"
        self._vector_field = vector_field
        self._metadata_field = metadata_field
        self._partition_key_field = partition_key_field
        self.fields: list[str] = []
        self.partition_names = partition_names
        self.replica_number = replica_number
        self.timeout = timeout
        self.num_shards = num_shards

        # Create the connection to the server
        if connection_args is None:
            connection_args = DEFAULT_MILVUS_CONNECTION
        self.alias = self._create_connection_alias(connection_args)
        self.col: Optional[Collection] = None

        # Grab the existing collection if it exists
        if utility.has_collection(self.collection_name, using=self.alias):
            self.col = Collection(
                self.collection_name,
                using=self.alias,
            )
            if drop_old:
                # If need to drop old, drop it (no point updating the
                # properties of a collection that is about to be removed).
                self.col.drop()
                self.col = None
            elif self.collection_properties is not None:
                self.col.set_properties(self.collection_properties)

        # Initialize the vector store
        self._init(
            partition_names=partition_names,
            replica_number=replica_number,
            timeout=timeout,
        )

    @property
    def embeddings(self) -> Embeddings:
        return self.embedding_func

    def _create_connection_alias(self, connection_args: dict) -> str:
        """Create (or reuse) a connection to the Milvus server.

        An existing pymilvus connection is reused when its address and user
        match the requested ones; otherwise a new connection with a random
        alias is opened using ``connection_args``.

        Returns:
            str: The alias of the connection to use.

        Raises:
            ValueError: If ``uri`` is used for the address and does not start
                with "http://" or "https://".
            MilvusException: If a new connection cannot be established.
        """
        from pymilvus import MilvusException, connections

        # Grab the connection arguments that are used for checking existing connection
        host: str = connection_args.get("host", None)
        port: Union[str, int] = connection_args.get("port", None)
        address: str = connection_args.get("address", None)
        uri: str = connection_args.get("uri", None)
        user = connection_args.get("user", None)

        # Order of use is host/port, uri, address
        if host is not None and port is not None:
            given_address = str(host) + ":" + str(port)
        elif uri is not None:
            if uri.startswith("https://"):
                given_address = uri.split("https://")[1]
            elif uri.startswith("http://"):
                given_address = uri.split("http://")[1]
            else:
                logger.error("Invalid Milvus URI: %s", uri)
                # Format the message; passing ``uri`` as a second argument
                # would only produce an args tuple, not an interpolated text.
                raise ValueError(f"Invalid Milvus URI: {uri}")
        elif address is not None:
            given_address = address
        else:
            given_address = None
            logger.debug("Missing standard address type for reuse attempt")

        # User defaults to empty string when getting connection info
        tmp_user = user if user is not None else ""

        # If a valid address was given, then check if a connection exists
        if given_address is not None:
            for con in connections.list_connections():
                addr = connections.get_connection_addr(con[0])
                if (
                    con[1]
                    and ("address" in addr)
                    and (addr["address"] == given_address)
                    and ("user" in addr)
                    and (addr["user"] == tmp_user)
                ):
                    logger.debug("Using previous connection: %s", con[0])
                    return con[0]

        # Generate a new connection if one doesn't exist
        alias = uuid4().hex
        try:
            connections.connect(alias=alias, **connection_args)
            logger.debug("Created new connection using: %s", alias)
            return alias
        except MilvusException as e:
            logger.error("Failed to create new connection using: %s", alias)
            raise e

    def _init(
        self,
        embeddings: Optional[list] = None,
        metadatas: Optional[list[dict]] = None,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
    ) -> None:
        """Create (if embeddings are given), index, and load the collection."""
        if embeddings is not None:
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()
        self._create_search_params()
        self._load(
            partition_names=partition_names,
            replica_number=replica_number,
            timeout=timeout,
        )

    def _create_collection(
        self, embeddings: list, metadatas: Optional[list[dict]] = None
    ) -> None:
        """Create the collection with a schema inferred from the first entity.

        The vector dim and dtype come from ``embeddings[0]``. Scalar metadata
        columns come from ``metadatas[0]`` unless ``metadata_field`` is set,
        in which case a single JSON column is used.

        Raises:
            ValueError: If a metadata value has an unrecognized datatype.
            MilvusException: If the collection cannot be created.
        """
        from pymilvus import (
            Collection,
            CollectionSchema,
            DataType,
            FieldSchema,
            MilvusException,
        )
        from pymilvus.orm.types import infer_dtype_bydata

        # Determine embedding dim
        dim = len(embeddings[0])
        fields = []
        if self._metadata_field is not None:
            fields.append(FieldSchema(self._metadata_field, DataType.JSON))
        elif metadatas:
            # Create FieldSchema for each entry in metadata.
            for key, value in metadatas[0].items():
                # Infer the corresponding datatype of the metadata
                dtype = infer_dtype_bydata(value)
                # Datatype isn't compatible
                if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
                    logger.error(
                        (
                            "Failure to create collection, "
                            "unrecognized dtype for key: %s"
                        ),
                        key,
                    )
                    raise ValueError(f"Unrecognized datatype for {key}.")
                # Dataype is a string/varchar equivalent
                elif dtype == DataType.VARCHAR:
                    fields.append(
                        FieldSchema(key, DataType.VARCHAR, max_length=65_535)
                    )
                else:
                    fields.append(FieldSchema(key, dtype))

        # Create the text field
        fields.append(
            FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
        )
        # Create the primary key field
        if self.auto_id:
            fields.append(
                FieldSchema(
                    self._primary_field, DataType.INT64, is_primary=True, auto_id=True
                )
            )
        else:
            fields.append(
                FieldSchema(
                    self._primary_field,
                    DataType.VARCHAR,
                    is_primary=True,
                    auto_id=False,
                    max_length=65_535,
                )
            )
        # Create the vector field, supports binary or float vectors
        fields.append(
            FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
        )

        # Create the schema for the collection
        schema = CollectionSchema(
            fields,
            description=self.collection_description,
            partition_key_field=self._partition_key_field,
        )

        # Create the collection
        try:
            extra_kwargs: dict[str, Any] = {}
            if self.num_shards is not None:
                # Only pass num_shards when explicitly set. Issue with defaults:
                # https://github.com/milvus-io/pymilvus/blob/59bf5e811ad56e20946559317fed855330758d9c/pymilvus/client/prepare.py#L82-L85
                extra_kwargs["num_shards"] = self.num_shards
            self.col = Collection(
                name=self.collection_name,
                schema=schema,
                consistency_level=self.consistency_level,
                using=self.alias,
                **extra_kwargs,
            )
            # Set the collection properties if they exist
            if self.collection_properties is not None:
                self.col.set_properties(self.collection_properties)
        except MilvusException as e:
            logger.error(
                "Failed to create collection: %s error: %s", self.collection_name, e
            )
            raise e

    def _extract_fields(self) -> None:
        """Append the collection schema's field names to ``self.fields``."""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            schema = self.col.schema
            for x in schema.fields:
                self.fields.append(x.name)

    def _get_index(self) -> Optional[dict[str, Any]]:
        """Return the vector field's index info as a dict, or None if absent."""
        from pymilvus import Collection

        if isinstance(self.col, Collection):
            for x in self.col.indexes:
                if x.field_name == self._vector_field:
                    return x.to_dict()
        return None

    def _create_index(self) -> None:
        """Create an index on the vector field if none exists yet.

        Uses ``self.index_params`` or, when unset, an HNSW index. If creation
        fails, falls back to an AUTOINDEX index (the comment below notes this
        is most likely Zilliz Cloud).

        Raises:
            MilvusException: If the fallback index cannot be created either.
        """
        from pymilvus import Collection, MilvusException

        if isinstance(self.col, Collection) and self._get_index() is None:
            try:
                # If no index params, use a default HNSW based one
                if self.index_params is None:
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }

                try:
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )

                # If default did not work, most likely on Zilliz Cloud
                except MilvusException:
                    # Use AUTOINDEX based index
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                logger.debug(
                    "Successfully created an index on collection: %s",
                    self.collection_name,
                )

            except MilvusException as e:
                logger.error(
                    "Failed to create an index on collection: %s", self.collection_name
                )
                raise e

    def _create_search_params(self) -> None:
        """Derive ``self.search_params`` from the current index, if unset.

        The template for the index type is copied so that the per-instance
        ``default_search_params`` table is never mutated. Index types with no
        registered template (e.g. FLAT) get empty ``params`` instead of
        raising KeyError.
        """
        from pymilvus import Collection

        if isinstance(self.col, Collection) and self.search_params is None:
            index = self._get_index()
            if index is not None:
                index_type: str = index["index_param"]["index_type"]
                metric_type: str = index["index_param"]["metric_type"]
                template = self.default_search_params.get(index_type, {"params": {}})
                self.search_params = {
                    **template,
                    "params": dict(template.get("params", {})),
                    "metric_type": metric_type,
                }

    def _load(
        self,
        partition_names: Optional[list] = None,
        replica_number: int = 1,
        timeout: Optional[float] = None,
    ) -> None:
        """Load the collection if it is indexed and not loaded yet."""
        from pymilvus import Collection, utility
        from pymilvus.client.types import LoadState

        timeout = self.timeout or timeout
        if (
            isinstance(self.col, Collection)
            and self._get_index() is not None
            and utility.load_state(self.collection_name, using=self.alias)
            == LoadState.NotLoad
        ):
            self.col.load(
                partition_names=partition_names,
                replica_number=replica_number,
                timeout=timeout,
            )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[float] = None,
        batch_size: int = 1000,
        *,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert text data into Milvus.

        Inserting data when the collection has not been created yet will
        result in creating a new collection. The data of the first entity
        decides the schema of the new collection: the dim is extracted from
        the first embedding and the columns are decided by the first metadata
        dict. Metadata keys need to be present for all inserted values. At
        the moment there is no None equivalent in Milvus.

        Args:
            texts (Iterable[str]): The texts to embed; it is assumed that
                they all fit in memory.
            metadatas (Optional[List[dict]]): Metadata dicts attached to each
                of the texts. Defaults to None. When ``metadata_field`` is set
                and this is None, an empty dict is stored for every text.
                Keys matching the primary, text or vector field names are
                ignored.
            timeout (Optional[float]): Timeout for each batch insert.
                Defaults to None; ``self.timeout`` takes precedence if set.
            batch_size (int, optional): Batch size to use for insertion.
                Defaults to 1000.
            ids (Optional[List[str]]): Unique ids for the texts, each at most
                65535 bytes. Required (and only used) when auto_id is False.
            **kwargs: Extra keyword arguments forwarded to
                ``Collection.insert()``.

        Raises:
            MilvusException: Failure to add texts.
            AssertionError: If auto_id is False and ``ids`` is missing, not
                unique/same length as ``texts``, or contains an over-long id.

        Returns:
            List[str]: The resulting keys for each inserted element.
        """
        from pymilvus import Collection, MilvusException

        texts = list(texts)
        if not self.auto_id:
            assert isinstance(
                ids, list
            ), "A list of valid ids are required when auto_id is False."
            assert len(set(ids)) == len(
                texts
            ), "Different lengths of texts and unique ids are provided."
            assert all(
                len(x.encode()) <= 65_535 for x in ids
            ), "Each id should be a string less than 65535 bytes."

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        # If the collection hasn't been initialized yet, perform all steps to do so.
        # A dedicated dict is used so the caller's ``kwargs`` (forwarded to
        # ``Collection.insert`` below) are not replaced by these init args.
        if not isinstance(self.col, Collection):
            init_kwargs: dict[str, Any] = {
                "embeddings": embeddings,
                "metadatas": metadatas,
            }
            if self.partition_names:
                init_kwargs["partition_names"] = self.partition_names
            if self.replica_number:
                init_kwargs["replica_number"] = self.replica_number
            if self.timeout:
                init_kwargs["timeout"] = self.timeout
            self._init(**init_kwargs)

        # Dict to hold all insert columns
        insert_dict: dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }

        if not self.auto_id:
            insert_dict[self._primary_field] = ids  # type: ignore[assignment]

        if self._metadata_field is not None:
            # The JSON metadata column needs one value per row.
            insert_dict[self._metadata_field] = (
                list(metadatas) if metadatas is not None else [{} for _ in texts]
            )
        elif metadatas is not None:
            # Collect the metadata into the insert dict. Only schema fields
            # are filled, and the reserved pk/text/vector columns are never
            # appended to (that would misalign or corrupt those columns).
            reserved = {self._primary_field, self._text_field, self._vector_field}
            metadata_keys = {x for x in self.fields if x not in reserved}
            for d in metadatas:
                for key, value in d.items():
                    if key in metadata_keys:
                        insert_dict.setdefault(key, []).append(value)

        # Total insert count
        total_count = len(insert_dict[self._vector_field])
        pks: list[str] = []

        assert isinstance(self.col, Collection)
        timeout = self.timeout or timeout
        for i in range(0, total_count, batch_size):
            # Grab end index
            end = min(i + batch_size, total_count)
            # Convert dict to list of lists batch for insertion, in schema order
            insert_list = [
                insert_dict[x][i:end] for x in self.fields if x in insert_dict
            ]
            # Insert into the collection.
            try:
                res = self.col.insert(insert_list, timeout=timeout, **kwargs)
                pks.extend(res.primary_keys)
            except MilvusException as e:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise e
        return pks

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against the query string.

        Args:
            query (str): The text to search.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search; empty if there is
            no collection.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        timeout = self.timeout or timeout
        res = self.similarity_search_with_score(
            query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search against an embedding vector.

        Args:
            embedding (List[float]): The embedding vector to search.
            k (int, optional): How many results to return. Defaults to 4.
            param (dict, optional): The search params for the index type.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search; empty if there is
            no collection.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []
        timeout = self.timeout or timeout
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return [doc for doc, _ in res]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on a query string and return results with score.

        For more information about the search parameters, take a look at the
        pymilvus documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            query (str): The text being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Tuple[Document, float]]: Result documents and their scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)
        timeout = self.timeout or timeout
        res = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )
        return res

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a search on an embedding and return results with score.

        For more information about the search parameters, take a look at the
        pymilvus documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): The amount of results to return. Defaults to 4.
            param (dict): The search params for the specified index.
                Defaults to None (uses ``self.search_params``).
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Tuple[Document, float]]: Result documents and scores.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields with PK.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)
        timeout = self.timeout or timeout
        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )
        # Organize results.
        ret = []
        for result in res[0]:
            data = {x: result.entity.get(x) for x in output_fields}
            doc = self._parse_document(data)
            pair = (doc, result.score)
            ret.append(pair)

        return ret

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results reordered by MMR.

        Args:
            query (str): The text being searched.
            k (int, optional): How many results to give. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            param (dict, optional): The search params for the specified
                index. Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        embedding = self.embedding_func.embed_query(query)
        timeout = self.timeout or timeout
        return self.max_marginal_relevance_search_by_vector(
            embedding=embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            param=param,
            expr=expr,
            timeout=timeout,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a search and return results reordered by MMR.

        Fetches ``fetch_k`` candidates, queries their stored vectors by
        primary key, and reranks them with ``maximal_marginal_relevance``.

        Args:
            embedding (list[float]): The embedding vector being searched.
            k (int, optional): How many results to give. Defaults to 4.
            fetch_k (int, optional): Total results to select k from.
                Defaults to 20.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            param (dict, optional): The search params for the specified
                index. Defaults to None (uses ``self.search_params``).
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (float, optional): How long to wait before timeout error.
                Defaults to None.
            kwargs: Collection.search() keyword arguments.

        Returns:
            List[Document]: Document results for search.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        if param is None:
            param = self.search_params

        # Determine result metadata fields.
        output_fields = self.fields[:]
        output_fields.remove(self._vector_field)
        timeout = self.timeout or timeout
        # Perform the search.
        res = self.col.search(
            data=[embedding],
            anns_field=self._vector_field,
            param=param,
            limit=fetch_k,
            expr=expr,
            output_fields=output_fields,
            timeout=timeout,
            **kwargs,
        )
        # Organize results.
        ids = []
        documents = []
        scores = []
        for result in res[0]:
            data = {x: result.entity.get(x) for x in output_fields}
            doc = self._parse_document(data)
            documents.append(doc)
            scores.append(result.score)
            ids.append(result.id)

        # Nothing matched: skip the follow-up vector query entirely.
        if not ids:
            return []

        vectors = self.col.query(
            expr=f"{self._primary_field} in {ids}",
            output_fields=[self._primary_field, self._vector_field],
            timeout=timeout,
        )
        # Reorganize the results from query to match search order.
        vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}

        ordered_result_embeddings = [vectors[x] for x in ids]

        # Get the new order of results.
        new_ordering = maximal_marginal_relevance(
            np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
        )

        # Reorder the values and return.
        ret = []
        for x in new_ordering:
            # Function can return -1 index
            if x == -1:
                break
            ret.append(documents[x])
        return ret

    def delete(  # type: ignore[no-untyped-def]
        self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
    ):
        """Delete by vector ID or boolean expression.

        Refer to [Milvus documentation](https://milvus.io/docs/delete_data.md)
        for notes and examples of expressions.

        Args:
            ids: List of ids to delete. When non-empty, takes precedence over
                ``expr``.
            expr: Boolean expression that specifies the entities to delete.
            kwargs: Other parameters in Milvus delete api.

        Returns:
            Whatever ``Collection.delete`` returns.

        Raises:
            AssertionError: If neither a non-empty ``ids`` list nor an
                ``expr`` string is provided.
        """
        if isinstance(ids, list) and len(ids) > 0:
            if expr is not None:
                logger.warning(
                    "Both ids and expr are provided. " "Ignore expr and delete by ids."
                )
            expr = f"{self._primary_field} in {ids}"
        else:
            assert isinstance(
                expr, str
            ), "Either ids list or expr string must be provided."
        return self.col.delete(expr=expr, **kwargs)  # type: ignore[union-attr]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        *,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Milvus:
        """Create a Milvus collection, index it with HNSW, and insert data.

        ``auto_id`` is enabled exactly when ``ids`` is not a non-empty list.

        Args:
            texts (List[str]): Text data.
            embedding (Embeddings): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it
                exists. Defaults to None.
            collection_name (str, optional): Collection name to use.
                Defaults to "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to
                use. Defaults to DEFAULT_MILVUS_CONNECTION.
            consistency_level (str, optional): Which consistency level to
                use. Defaults to "Session".
            index_params (Optional[dict], optional): Which index_params to
                use. Defaults to None.
            search_params (Optional[dict], optional): Which search params to
                use. Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the
                collection with that name if it exists. Defaults to False.
            ids (Optional[List[str]]): List of text ids. Defaults to None.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            Milvus: Milvus Vector Store
        """
        if isinstance(ids, list) and len(ids) > 0:
            auto_id = False
        else:
            auto_id = True

        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args,
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            auto_id=auto_id,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return vector_db

    def _parse_document(self, data: dict) -> Document:
        """Build a Document from a result row (``data`` is mutated via pop).

        The text field becomes ``page_content``. If ``metadata_field`` is set,
        that field's value is the metadata; otherwise all remaining fields
        are used as metadata.
        """
        return Document(
            page_content=data.pop(self._text_field),
            metadata=data.pop(self._metadata_field) if self._metadata_field else data,
        )

    def get_pks(self, expr: str, **kwargs: Any) -> List[int] | None:
        """Get primary keys matching an expression.

        Args:
            expr: Expression, e.g. "id in [1, 2]" or "title LIKE 'Abc%'".

        Returns:
            List[int]: List of ids (primary keys), or None if there is no
            collection.

        Raises:
            MilvusException: If the query fails.
        """
        from pymilvus import MilvusException

        if self.col is None:
            logger.debug("No existing collection to get pk.")
            return None

        try:
            query_result = self.col.query(
                expr=expr, output_fields=[self._primary_field]
            )
        except MilvusException as exc:
            logger.error("Failed to get ids: %s error: %s", self.collection_name, exc)
            raise exc
        pks = [item.get(self._primary_field) for item in query_result]
        return pks

    def upsert(
        self,
        ids: Optional[List[str]] = None,
        documents: List[Document] | None = None,
        **kwargs: Any,
    ) -> List[str] | None:
        """Update/insert documents to the vector store.

        Deletes ``ids`` (best effort) and then adds ``documents``.

        Args:
            ids: IDs to update - use get_pks to obtain ids with an expression.
            documents (List[Document]): Documents to add to the vectorstore.
            **kwargs: Forwarded to ``add_documents``.

        Returns:
            List[str]: IDs of the added texts, or None if there were no
            documents.

        Raises:
            MilvusException: If adding the documents fails.
        """
        from pymilvus import MilvusException

        if documents is None or len(documents) == 0:
            logger.debug("No documents to upsert.")
            return None

        if ids is not None and len(ids):
            try:
                self.delete(ids=ids)
            except MilvusException as exc:
                # Deletion is best-effort; continue with the insert but make
                # the failure visible instead of swallowing it silently.
                logger.warning(
                    "Failed to delete entities before upsert: %s error: %s",
                    self.collection_name,
                    exc,
                )
        try:
            return self.add_documents(documents=documents, **kwargs)
        except MilvusException as exc:
            logger.error(
                "Failed to upsert entities: %s error: %s", self.collection_name, exc
            )
            raise exc