# This module contains utility classes and functions for interacting with the Arcee API.
# For more information and updates, refer to the Arcee client source:
# https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.retrievers import Document
class ArceeRoute(str, Enum):
    """Routes available on the Arcee API, as an enumeration.

    Each member's value is the route path relative to the versioned API root.
    """

    generate = "models/generate"
    retrieve = "models/retrieve"
    # Template: fill `id_or_name` with `str.format` before use.
    model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
    """Filter types available for a DALM retrieval, as an enumeration."""

    fuzzy_search = "fuzzy_search"
    strict_search = "strict_search"
class DALMFilter(BaseModel):
    """Filter used for DALM retrieval and generation.

    Args:
        field_name: The field to filter on. Can be 'document' or 'name' to
            filter on the raw text or the title of a document. Any other
            field is assumed to be a metadata field supplied when the
            context data was uploaded.
        filter_type: Currently 'fuzzy_search' and 'strict_search' are
            supported. 'fuzzy_search' performs a fuzzy search on the given
            field; the exact string does not need to exist in the document
            for a match to be found, which makes it useful for scanning a
            document for keyword terms. 'strict_search' requires the exact
            string to appear in the given field. This is NOT an exact
            equality filter: a document with content
            "the happy dog crossed the street" matches a strict_search for
            "dog" but does not match "the dog". The Python equivalent is
            `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata.
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    # NOTE(review): this attribute is named `_is_metadata`, while the validator
    # below stores its result under the key `_is_meta`. The names do not match;
    # confirm which one the API expects before renaming either.
    _is_metadata: bool = False

    @root_validator()
    def set_meta(cls, values: Dict) -> Dict:
        """Flag the filter as a metadata filter.

        'document' and 'name' are reserved keys; any other field name is
        treated as metadata. The flag is written to `values["_is_meta"]`.
        """
        values["_is_meta"] = values.get("field_name") not in ["document", "name"]
        return values
class ArceeDocumentSource(BaseModel):
    """Source of an Arcee document."""

    # Mapped to the LangChain document's `page_content` by ArceeDocumentAdapter.
    document: str
    # Mapped to the `name` metadata entry.
    name: str
    # Mapped to the `source_id` metadata entry.
    id: str
class ArceeDocument(BaseModel):
    """A single Arcee document, as built from one entry of a retrieve response."""

    index: str
    id: str
    score: float
    # Nested payload holding the document text, name, and source id.
    source: ArceeDocumentSource
class ArceeDocumentAdapter:
    """Adapter for Arcee documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Adapt an `ArceeDocument` to a LangChain `Document`.

        Args:
            arcee_document: The Arcee document to convert.

        Returns:
            A `Document` whose `page_content` is the source text and whose
            metadata carries the source name/id plus the document's index,
            id, and score.
        """
        return Document(
            page_content=arcee_document.source.document,
            metadata={
                # arcee document; source metadata
                "name": arcee_document.source.name,
                "source_id": arcee_document.source.id,
                # arcee document metadata
                "index": arcee_document.index,
                "id": arcee_document.id,
                "score": arcee_document.score,
            },
        )
class ArceeWrapper:
    """Wrapper around the Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize the wrapper and fetch the model's training status.

        Args:
            arcee_api_key: API key for the Arcee API. A plain string is
                wrapped in a `SecretStr`.
            arcee_api_url: Base URL of the Arcee API.
            arcee_api_version: Version segment of the Arcee API.
            model_kwargs: Default keyword arguments for model requests.
            model_name: Name of the Arcee model.

        Raises:
            ValueError: If the training-status request fails for any reason.
        """
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key

        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise if the model's recorded status is not 'training_complete'.

        Raises:
            Exception: If `self.model_training_status` is anything other
                than "training_complete".
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Send a request to the Arcee API.

        Args:
            method: HTTP method to use.
            route: Route to call.
            body: Body of the request, sent as JSON.
            params: Query parameters of the request.
            headers: Extra headers of the request.

        Returns:
            The decoded JSON response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # Resolve `requests.get` / `requests.post` from the method name.
        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Build the request headers, adding the auth token and content type.

        The caller's mapping is copied, never mutated. The internal
        `X-Token` and `Content-Type` entries override same-named entries
        supplied by the caller.

        Raises:
            TypeError: If `self.arcee_api_key` is not a `SecretStr`.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        api_key = self.arcee_api_key.get_secret_value()
        merged_headers = dict(headers) if headers else {}
        merged_headers.update(
            {
                "X-Token": api_key,
                "Content-Type": "application/json",
            }
        )
        return merged_headers

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join the base URL, API version, and route into a full URL."""
        # Formatting a (str, Enum) member directly is version-dependent: newer
        # Pythons render "ClassName.member" instead of the value. Use `.value`.
        route_path = route.value if isinstance(route, Enum) else route
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{route_path}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Any
    ) -> Mapping[str, Any]:
        """Build the request body for the generate/retrieve model endpoints.

        Per-call `kwargs` override the defaults in `self.model_kwargs`.
        Filters may be given as dicts or as `DALMFilter` instances; each is
        validated and converted to a plain dict so that the body can be
        JSON-encoded by `requests`.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # `or []` also covers an explicit `filters=None`.
        raw_filters = _params.get("filters") or []
        filters = [
            (f if isinstance(f, DALMFilter) else DALMFilter(**f)).dict()
            for f in raw_filters
        ]

        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from the Arcee DALM.

        Args:
            prompt: Prompt to generate text from.
            size: Maximum number of context results to retrieve. Defaults
                to 3 (may be fewer if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            The `text` entry of the API response.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=self._make_request_body_for_models(
                prompt=prompt,
                **kwargs,
            ),
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve up to `size` contexts for the given query.

        Args:
            query: Query to submit to the model.
            size: Maximum number of context results to retrieve. Defaults
                to 3 (may be fewer if filters are provided).
            filters: Filters to apply to the context dataset.

        Returns:
            One `Document` per entry in the response's `results` list.
        """
        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=self._make_request_body_for_models(
                prompt=query,
                **kwargs,
            ),
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]