Source code for langchain_community.embeddings.premai

from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional, Union

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class PremAIEmbeddings(BaseModel, Embeddings):
    """Embeddings backed by the Prem AI embeddings API.

    On construction, ``validate_environments`` imports the ``premai`` package,
    resolves the API key from the ``premai_api_key`` field or the
    ``PREMAI_API_KEY`` environment variable, and stores a ``Prem`` client in
    ``client``. Embedding requests go through ``embed_with_retry``.
    """

    project_id: int
    """ID of the project in which the experiments or deployments live.
    All of your projects are listed at https://app.premai.io/projects/"""

    premai_api_key: Optional[SecretStr] = None
    """Prem AI API key. Get one at https://app.premai.io/api_keys/"""

    model: str
    """Name of the embedding model to use."""

    show_progress_bar: bool = False
    """Whether to show a tqdm progress bar. Requires `tqdm` to be installed.
    Not read by any method visible in this class."""

    max_retries: int = 1
    """Maximum number of retries, forwarded to the tenacity-based decorator."""

    client: Any

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Check that ``premai`` is installed and build the API client.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same mapping with ``values["client"]`` set to a ``Prem`` client.

        Raises:
            ImportError: If the ``premai`` package cannot be imported.
            ValueError: If the API key cannot be resolved or the client
                cannot be constructed.
        """
        try:
            from premai import Prem
        except ImportError as error:
            # The two literals are concatenated, so the first must end with a
            # space to keep the sentences apart in the final message.
            raise ImportError(
                "Could not import Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            # This validator runs after field validation, so a key supplied by
            # the caller arrives wrapped in SecretStr, while a key read from
            # the environment is a plain str. The client needs the raw string.
            if isinstance(premai_api_key, SecretStr):
                premai_api_key = premai_api_key.get_secret_value()
            values["client"] = Prem(api_key=premai_api_key)
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of ``text``, taken from the first item of the
            response's ``data``.
        """
        embeddings = embed_with_retry(
            self, model=self.model, project_id=self.project_id, input=text
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents in a single request.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per item of the response's ``data``. An empty
            ``texts`` yields an empty list without contacting the API.
        """
        if not texts:
            # Nothing to embed; skip the network round trip.
            return []
        embeddings = embed_with_retry(
            self, model=self.model, project_id=self.project_id, input=texts
        ).data
        return [embedding.embedding for embedding in embeddings]
def create_prem_retry_decorator(
    embedder: PremAIEmbeddings,
    *,
    max_retries: int = 1,
) -> Callable[[Any], Any]:
    """Build the retry decorator used for Prem AI embedding calls.

    Args:
        embedder: The PremAIEmbeddings instance. Accepted for interface
            compatibility; it is not read by this function.
        max_retries: Maximum number of retries.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        Prem error types listed below.
    """
    import premai.models

    prem_models = premai.models

    # Exception classes from the Prem SDK that trigger a retry.
    retryable_errors = [
        prem_models.api_response_validation_error.APIResponseValidationError,
        prem_models.conflict_error.ConflictError,
        prem_models.model_not_found_error.ModelNotFoundError,
        prem_models.permission_denied_error.PermissionDeniedError,
        prem_models.provider_api_connection_error.ProviderAPIConnectionError,
        prem_models.provider_api_status_error.ProviderAPIStatusError,
        prem_models.provider_api_timeout_error.ProviderAPITimeoutError,
        prem_models.provider_internal_server_error.ProviderInternalServerError,
        prem_models.provider_not_found_error.ProviderNotFoundError,
        prem_models.rate_limit_error.RateLimitError,
        prem_models.unprocessable_entity_error.UnprocessableEntityError,
        prem_models.validation_error.ValidationError,
    ]

    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=max_retries,
        run_manager=None,
    )
def embed_with_retry(
    embedder: PremAIEmbeddings,
    model: str,
    project_id: int,
    input: Union[str, List[str]],
) -> Any:
    """Call the Prem embeddings endpoint, retrying via tenacity.

    Args:
        embedder: Instance whose ``client`` performs the request and whose
            ``max_retries`` configures the retry decorator.
        model: Name of the embedding model.
        project_id: ID of the Prem project.
        input: A single text or a list of texts to embed.

    Returns:
        The response object returned by ``embedder.client.embeddings.create``.
    """
    with_retries = create_prem_retry_decorator(
        embedder, max_retries=embedder.max_retries
    )

    def _request_embeddings() -> Any:
        # The request arguments are captured from the enclosing scope.
        return embedder.client.embeddings.create(
            project_id=project_id, model=model, input=input
        )

    return with_retries(_request_embeddings)()