# Source code for langchain_community.utilities.arcee

# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]

from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import requests
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.retrievers import Document


class ArceeRoute(str, Enum):
    """Routes available on the Arcee API, exposed as an enumeration.

    Each value is a path fragment that is appended to the API base URL and
    version when a request URL is built.
    """

    # Endpoint used to generate text from a DALM model.
    generate = "models/generate"
    # Endpoint used to retrieve context documents from a DALM model.
    retrieve = "models/retrieve"
    # Training-status endpoint; ``{id_or_name}`` is a ``str.format`` placeholder
    # that callers fill in with the model's id or name.
    model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
    """Filter types available for a DALM retrieval, exposed as an enumeration."""

    # Approximate match on the filtered field.
    fuzzy_search = "fuzzy_search"
    # Exact-substring match on the filtered field.
    strict_search = "strict_search"
class DALMFilter(BaseModel):
    """Filter applied to a DALM retrieval or generation request.

    Arguments:
        field_name: The field to filter on. Can be 'document' or 'name' to
            filter on a document's raw text or title. Any other field is
            presumed to be a metadata field supplied when the context data
            was uploaded.
        filter_type: Currently 'fuzzy_search' and 'strict_search' are
            supported. 'fuzzy_search' performs a fuzzy search on the provided
            field; the exact string does not need to exist in the document
            for a match to be found, which is useful for scanning documents
            for keyword terms. 'strict_search' requires the exact string to
            appear in the provided field. This is NOT an exact-equality
            filter: a document with content "the happy dog crossed the
            street" matches a strict_search for "dog" but not for "the dog".
            Python equivalent: `return search_string in full_string`.
        value: The actual value to search for in the context data/metadata.
    """

    field_name: str
    filter_type: DALMFilterType
    value: str
    _is_metadata: bool = False

    @root_validator()
    def set_meta(cls, values: Dict) -> Dict:
        """Flag the filter as a metadata filter unless it targets a reserved field.

        'document' and 'name' are reserved keys; anything else is metadata.
        """
        # NOTE(review): the flag is stored under "_is_meta", while the attribute
        # declared on the class is "_is_metadata". The key is left untouched
        # here because it is runtime data; confirm which name the Arcee API
        # expects before unifying the two.
        field = values.get("field_name")
        values["_is_meta"] = field not in ["document", "name"]
        return values
class ArceeDocumentSource(BaseModel):
    """Source of an Arcee document."""

    # Text of the source document; used as ``page_content`` when adapted.
    document: str
    # Name of the source document.
    name: str
    # Identifier of the source document.
    id: str
class ArceeDocument(BaseModel):
    """Arcee document, as returned in a retrieval result."""

    index: str
    id: str
    score: float
    # Nested record carrying the document text, name and source id.
    source: ArceeDocumentSource
class ArceeDocumentAdapter:
    """Adapter for Arcee documents."""

    @classmethod
    def adapt(cls, arcee_document: ArceeDocument) -> Document:
        """Adapt an `ArceeDocument` to a langchain `Document` object.

        Args:
            arcee_document: The Arcee document to convert.

        Returns:
            A `Document` whose ``page_content`` is the source text and whose
            ``metadata`` combines the source fields (``name``, ``source_id``)
            with the document fields (``index``, ``id``, ``score``).
        """
        source = arcee_document.source
        metadata = {
            # Metadata taken from the document's source record.
            "name": source.name,
            "source_id": source.id,
            # Metadata taken from the Arcee document itself.
            "index": arcee_document.index,
            "id": arcee_document.id,
            "score": arcee_document.score,
        }
        return Document(page_content=source.document, metadata=metadata)
class ArceeWrapper:
    """Wrapper for the Arcee API.

    For more details, see: https://www.arcee.ai/
    """

    def __init__(
        self,
        arcee_api_key: Union[str, SecretStr],
        arcee_api_url: str,
        arcee_api_version: str,
        model_kwargs: Optional[Dict[str, Any]],
        model_name: str,
    ):
        """Initialize the ArceeWrapper.

        Construction issues a GET request to the model-training-status route
        to resolve the model id and its current training status.

        Arguments:
            arcee_api_key: API key for the Arcee API. A plain string is
                wrapped in a `SecretStr`.
            arcee_api_url: Base URL of the Arcee API.
            arcee_api_version: Version segment of the Arcee API.
            model_kwargs: Default keyword arguments merged into every
                generate/retrieve request body (per-call kwargs win).
            model_name: Name of the Arcee model.

        Raises:
            ValueError: If the training-status lookup fails for any reason.
        """
        if isinstance(arcee_api_key, str):
            arcee_api_key_ = SecretStr(arcee_api_key)
        else:
            arcee_api_key_ = arcee_api_key
        self.arcee_api_key: SecretStr = arcee_api_key_
        self.model_kwargs = model_kwargs
        self.arcee_api_url = arcee_api_url
        self.arcee_api_version = arcee_api_version

        try:
            route = ArceeRoute.model_training_status.value.format(
                id_or_name=model_name
            )
            response = self._make_request("get", route)
            self.model_id = response.get("model_id")
            self.model_training_status = response.get("status")
        except Exception as e:
            raise ValueError(
                f"Error while validating model training status for '{model_name}': {e}"
            ) from e

    def validate_model_training_status(self) -> None:
        """Raise if the model has not finished training.

        Raises:
            Exception: If the recorded status is not "training_complete".
        """
        if self.model_training_status != "training_complete":
            raise Exception(
                f"Model {self.model_id} is not ready. "
                "Please wait for training to complete."
            )

    def _make_request(
        self,
        method: Literal["post", "get"],
        route: Union[ArceeRoute, str],
        body: Optional[Mapping[str, Any]] = None,
        params: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> dict:
        """Make a request to the Arcee API.

        Arguments:
            method: The HTTP method to use.
            route: The route to call.
            body: The body of the request, sent as JSON.
            params: The query parameters of the request.
            headers: Extra headers for the request.

        Returns:
            The decoded JSON response.

        Raises:
            Exception: If the response status code is neither 200 nor 201.
        """
        headers = self._make_request_headers(headers=headers)
        url = self._make_request_url(route=route)

        # Resolves to ``requests.post`` or ``requests.get``.
        req_type = getattr(requests, method)

        response = req_type(url, json=body, params=params, headers=headers)
        if response.status_code not in (200, 201):
            raise Exception(f"Failed to make request. Response: {response.text}")
        return response.json()

    def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
        """Return the request headers, with auth and content type applied.

        The caller's mapping is copied rather than updated in place, and the
        internal headers take precedence over caller-supplied ones.

        Raises:
            TypeError: If ``self.arcee_api_key`` is not a `SecretStr`.
        """
        if not isinstance(self.arcee_api_key, SecretStr):
            raise TypeError(
                f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
            )
        api_key = self.arcee_api_key.get_secret_value()
        internal_headers = {
            "X-Token": api_key,
            "Content-Type": "application/json",
        }
        return {**(headers or {}), **internal_headers}

    def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
        """Join the base URL, API version and route into the endpoint URL."""
        # On Python 3.11+ formatting a ``(str, Enum)`` member yields
        # "ArceeRoute.<name>" rather than its value, so unwrap it explicitly.
        route_path = route.value if isinstance(route, ArceeRoute) else route
        return f"{self.arcee_api_url}/{self.arcee_api_version}/{route_path}"

    def _make_request_body_for_models(
        self, prompt: str, **kwargs: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Build the request body for the generate/retrieve model endpoints.

        ``self.model_kwargs`` supplies defaults; ``kwargs`` override them.
        Only ``size`` (default 3) and ``filters`` are read from the merged
        parameters.
        """
        _model_kwargs = self.model_kwargs or {}
        _params = {**_model_kwargs, **kwargs}

        # Validate each filter through DALMFilter, then dump it back to a plain
        # dict: the body is sent with ``json=`` and model instances are not
        # JSON-serializable. ``or []`` also tolerates an explicit ``None``.
        raw_filters = _params.get("filters") or []
        filters = [DALMFilter(**f).dict() for f in raw_filters]

        return dict(
            model_id=self.model_id,
            query=prompt,
            size=_params.get("size", 3),
            filters=filters,
            id=self.model_id,
        )

    def generate(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Generate text from Arcee DALM.

        Arguments:
            prompt: Prompt to generate text from.
            size: The max number of context results to retrieve. Defaults
                to 3. (Can be less if filters are provided.)
            filters: Filters to apply to the context dataset.

        Returns:
            The ``"text"`` field of the API response.
        """
        body = self._make_request_body_for_models(prompt=prompt, **kwargs)
        response = self._make_request(
            method="post",
            route=ArceeRoute.generate.value,
            body=body,
        )
        return response["text"]

    def retrieve(
        self,
        query: str,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve up to ``size`` context documents for the given query.

        Arguments:
            query: Query to submit to the model.
            size: The max number of context results to retrieve. Defaults
                to 3. (Can be less if filters are provided.)
            filters: Filters to apply to the context dataset.

        Returns:
            One langchain `Document` per entry in the response's
            ``"results"`` list.
        """
        body = self._make_request_body_for_models(prompt=query, **kwargs)
        response = self._make_request(
            method="post",
            route=ArceeRoute.retrieve.value,
            body=body,
        )
        return [
            ArceeDocumentAdapter.adapt(ArceeDocument(**doc))
            for doc in response["results"]
        ]