Source code for langchain_community.utilities.dria_index
import logging
from typing import Any, Dict, List, Optional, Union
# Module-level logger named after this module; DriaAPIWrapper uses it for
# its informational and error messages.
logger = logging.getLogger(__name__)
class DriaAPIWrapper:
    """Wrapper around the Dria API.

    Simplifies interaction with Dria's vector search and retrieval service:
    creating a knowledge base, inserting data, and fetching search results.
    Every operation is forwarded to an underlying ``dria.Dria`` client.

    Attributes:
        api_key: API key used to access Dria.
        contract_id: Contract ID of the knowledge base to interact with.
        top_n: Number of top results to request for each search or query.
        models: The ``Models`` object imported from the ``dria`` package.
        dria_client: The ``dria.Dria`` client all calls are forwarded to.
    """

    def __init__(
        self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10
    ):
        """Build the Dria client and optionally select a knowledge base.

        Args:
            api_key: API key passed to the ``dria.Dria`` constructor.
            contract_id: Contract ID of an existing knowledge base. When
                truthy, it is registered on the client via ``set_contract``.
            top_n: Number of results requested by ``search`` and
                ``query_with_vector``.

        Raises:
            ImportError: If the ``dria`` package is not installed.
        """
        try:
            # Imported lazily so this module can be imported without dria.
            from dria import Dria, Models
        except ImportError as exc:
            message = (
                "Dria is not installed. Please install Dria to use this wrapper.\n"
                "You can install Dria using the following command:\n"
                "pip install dria\n"
            )
            logger.error(message)
            # Fail here instead of returning a half-initialised object whose
            # methods would later die with an unrelated AttributeError.
            raise ImportError(message) from exc

        self.api_key = api_key
        self.models = Models
        self.contract_id = contract_id
        self.top_n = top_n
        self.dria_client = Dria(api_key=self.api_key)
        if self.contract_id:
            self.dria_client.set_contract(self.contract_id)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str,
        embedding: str,
    ) -> str:
        """Create a new knowledge base.

        Args:
            name: Forwarded to the client's ``create`` call as ``name``.
            description: Forwarded as ``description``.
            category: Forwarded as ``category``.
            embedding: Forwarded as ``embedding``.

        Returns:
            The value returned by the client's ``create`` call, which is also
            stored on ``self.contract_id``.
        """
        contract_id = self.dria_client.create(
            name=name, embedding=embedding, category=category, description=description
        )
        logger.info("Knowledge base created with ID: %s", contract_id)
        self.contract_id = contract_id
        return contract_id

    def insert_data(self, data: List[Dict[str, Any]]) -> str:
        """Insert data into the knowledge base.

        Args:
            data: Records handed unchanged to the client's ``insert_text``.

        Returns:
            The response returned by the client's ``insert_text`` call.
        """
        response = self.dria_client.insert_text(data)
        logger.info("Data inserted: %s", response)
        return response

    def search(self, query: str) -> List[Dict[str, Any]]:
        """Run a text-based search.

        Args:
            query: Text passed to the client's ``search`` call.

        Returns:
            The results returned by the client, requested with
            ``top_n=self.top_n``.
        """
        results = self.dria_client.search(query, top_n=self.top_n)
        logger.info("Search results: %s", results)
        return results

    def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]:
        """Run a vector-based query.

        Args:
            vector: Query vector passed to the client's ``query`` call.

        Returns:
            The results returned by the client, requested with
            ``top_n=self.top_n``.
        """
        vector_query_results = self.dria_client.query(vector, top_n=self.top_n)
        logger.info("Vector query results: %s", vector_query_results)
        return vector_query_results

    def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]:
        """Dispatch to a text search or a vector query based on the input type.

        Args:
            query: A string for a text search, or a list of numbers for a
                vector query. Integers are accepted and converted to floats;
                booleans are rejected.

        Returns:
            The search or query results from Dria, or ``None`` (after logging
            an error) when ``query`` is neither a string nor a list of numbers.
        """
        if isinstance(query, str):
            return self.search(query)

        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(query, list) and all(
            isinstance(item, (int, float)) and not isinstance(item, bool)
            for item in query
        ):
            return self.query_with_vector([float(item) for item in query])

        logger.error(
            "Invalid query type. Please provide a string for text search or a\n"
            "list of floats for vector query."
        )
        return None