"""封装LLMRails向量数据库。"""
from __future__ import annotations
import json
import logging
import os
import uuid
from typing import Any, Iterable, List, Optional, Tuple
import requests
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
[docs]class LLMRails(VectorStore):
"""使用LLMRails实现向量存储。
请参阅 https://llmrails.com/
示例:
.. code-block:: python
from langchain_community.vectorstores import LLMRails
vectorstore = LLMRails(
api_key=llm_rails_api_key,
datastore_id=datastore_id
)"""
[docs] def __init__(
self,
datastore_id: Optional[str] = None,
api_key: Optional[str] = None,
):
"""使用LLMRails API 进行初始化。"""
self._datastore_id = datastore_id or os.environ.get("LLM_RAILS_DATASTORE_ID")
self._api_key = api_key or os.environ.get("LLM_RAILS_API_KEY")
if self._api_key is None:
logging.warning("Can't find Rails credentials in environment.")
self._session = requests.Session() # to reuse connections
self.datastore_id = datastore_id
self.base_url = "https://api.llmrails.com/v1"
def _get_post_headers(self) -> dict:
"""返回应附加到每个POST请求的标头。"""
return {"X-API-KEY": self._api_key}
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""运行更多的文本通过嵌入并添加到向量存储中。
参数:
texts: 要添加到向量存储中的字符串的可迭代对象。
返回:
将文本添加到向量存储中后的ID列表。
"""
names: List[str] = []
for text in texts:
doc_name = str(uuid.uuid4())
response = self._session.post(
f"{self.base_url}/datastores/{self._datastore_id}/text",
json={"name": doc_name, "text": text},
verify=True,
headers=self._get_post_headers(),
)
if response.status_code != 200:
logging.error(
f"Create request failed for doc_name = {doc_name} with status code "
f"{response.status_code}, reason {response.reason}, text "
f"{response.text}"
)
return names
names.append(doc_name)
return names
[docs] def add_files(
self,
files_list: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> bool:
"""LLMRails提供了一种通过我们的API直接添加文档的方式,其中预处理和分块在内部以最佳方式进行。
这种方法提供了一种在LangChain中使用该API的方式
参数:
files_list:字符串的可迭代对象,每个字符串表示一个本地文件路径。
文件可以是文本、HTML、PDF、markdown、doc/docx、ppt/pptx等。
请查看API文档以获取完整列表
返回:
与索引的每个文件相关联的id列表
"""
files = []
for file in files_list:
if not os.path.exists(file):
logging.error(f"File {file} does not exist, skipping")
continue
files.append(("file", (os.path.basename(file), open(file, "rb"))))
response = self._session.post(
f"{self.base_url}/datastores/{self._datastore_id}/file",
files=files,
verify=True,
headers=self._get_post_headers(),
)
if response.status_code != 200:
logging.error(
f"Create request failed for datastore = {self._datastore_id} "
f"with status code {response.status_code}, reason {response.reason}, "
f"text {response.text}"
)
return False
return True
[docs] def similarity_search_with_score(
self, query: str, k: int = 5
) -> List[Tuple[Document, float]]:
"""返回LLMRails文档与查询最相似的文档,以及相似度分数。
参数:
query:要查找相似文档的文本。
k:要返回的文档数量。默认为5,最大为10。
alpha:混合搜索的参数。
返回:
返回与查询最相似的文档列表,以及每个文档的相似度分数。
"""
response = self._session.post(
headers=self._get_post_headers(),
url=f"{self.base_url}/datastores/{self._datastore_id}/search",
data=json.dumps({"k": k, "text": query}),
timeout=10,
)
if response.status_code != 200:
logging.error(
"Query failed %s",
f"(code {response.status_code}, reason {response.reason}, details "
f"{response.text})",
)
return []
results = response.json()["results"]
docs = [
(
Document(
page_content=x["text"],
metadata={
key: value
for key, value in x["metadata"].items()
if key != "score"
},
),
x["metadata"]["score"],
)
for x in results
]
return docs
[docs] def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""返回LLMRails文档与查询最相似的文档,以及分数。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为5。
返回:
与查询最相似的文档列表
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, _ in docs_and_scores]
[docs] @classmethod
def from_texts(
cls,
texts: List[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> LLMRails:
"""从原始文档构建LLMRails包装器。
这旨在是一个快速入门的方式。
示例:
.. code-block:: python
from langchain_community.vectorstores import LLMRails
llm_rails = LLMRails.from_texts(
texts,
datastore_id=datastore_id,
api_key=llm_rails_api_key
)
"""
# Note: LLMRails generates its own embeddings, so we ignore the provided
# embeddings (required by interface)
llm_rails = cls(**kwargs)
llm_rails.add_texts(texts)
return llm_rails
[docs] def as_retriever(self, **kwargs: Any) -> LLMRailsRetriever:
return LLMRailsRetriever(vectorstore=self, **kwargs)
[docs]class LLMRailsRetriever(VectorStoreRetriever):
"""LLMRails的检索器。"""
vectorstore: LLMRails
search_kwargs: dict = Field(default_factory=lambda: {"k": 5})
"""搜索参数。
k: 返回的文档数量。默认为5。
alpha: 混合搜索的参数。"""
[docs] def add_texts(self, texts: List[str]) -> None:
"""将文本添加到数据存储中。
参数:
texts(List[str]):文本
"""
self.vectorstore.add_texts(texts)