# Source code for langchain_community.cross_encoders.sagemaker_endpoint

import json
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.cross_encoders.base import BaseCrossEncoder


class CrossEncoderContentHandler:
    """Content handler for the CrossEncoder class.

    Serializes text pairs into the JSON request body sent to the endpoint
    and extracts the list of scores from the endpoint's JSON response.
    """

    # MIME type of the request body produced by ``transform_input``.
    content_type = "application/json"
    # MIME type requested for the response body.
    accepts = "application/json"

    def transform_input(self, text_pairs: List[Tuple[str, str]]) -> bytes:
        """Encode text pairs as a UTF-8 JSON request body.

        Args:
            text_pairs: The pairs of texts to be scored.

        Returns:
            The UTF-8 bytes of a JSON object of the form
            ``{"text_pairs": [...]}``.
        """
        input_str = json.dumps({"text_pairs": text_pairs})
        return input_str.encode("utf-8")

    def transform_output(self, output: Any) -> List[float]:
        """Decode the endpoint response and return its scores.

        Args:
            output: A readable object whose ``read()`` returns the UTF-8
                encoded JSON response body.

        Returns:
            The value stored under the ``"scores"`` key of the response.

        Raises:
            KeyError: If the response has no ``"scores"`` key.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        scores = response_json["scores"]
        return scores
class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder):
    """SageMaker Inference CrossEncoder endpoint.

    To use, you must supply the endpoint name of your deployed
    Sagemaker model and the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is
    to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders.sagemaker_endpoint import (
                SagemakerEndpointCrossEncoder,
            )

            endpoint_name = (
                "my-endpoint-name"
            )
            region_name = (
                "us-west-2"
            )
            credentials_profile_name = (
                "default"
            )
            se = SagemakerEndpointCrossEncoder(
                endpoint_name=endpoint_name,
                region_name=region_name,
                credentials_profile_name=credentials_profile_name
            )
    """

    client: Any  #: :meta private:

    endpoint_name: str = ""
    """The name of the endpoint of the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    region_name: str = ""
    """The AWS region where the Sagemaker model is deployed, e.g. `us-west-2`."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files,
    which has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    content_handler: CrossEncoderContentHandler = CrossEncoderContentHandler()

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes passed to the invoke_endpoint function.
    See the `boto3`_ docs for more info.

    .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that AWS credentials and the boto3 package are available.

        Builds a ``sagemaker-runtime`` client and stores it under
        ``values["client"]``, unless the caller already supplied one.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same mapping, with ``"client"`` populated.

        Raises:
            ImportError: If ``boto3`` is not installed.
            ValueError: If the boto3 session or client cannot be created.
        """
        # Respect a client injected by the caller instead of replacing it.
        if values.get("client") is not None:
            return values

        # Keep the import in its own try so only a missing package is
        # reported as an ImportError.
        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e

        # ``.get`` avoids a KeyError (which would be misreported as a
        # credentials problem below) when a field failed its own validation.
        profile_name = values.get("credentials_profile_name")
        try:
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            # Forward an unset/empty region as None so boto3 applies its own
            # region lookup rather than receiving an empty string.
            values["client"] = session.client(
                "sagemaker-runtime",
                region_name=values.get("region_name") or None,
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Call out to the SageMaker Inference CrossEncoder endpoint.

        Args:
            text_pairs: The pairs of texts to score.

        Returns:
            The scores produced by ``content_handler.transform_output`` for
            the endpoint response; an empty list when ``text_pairs`` is empty.

        Raises:
            ValueError: If invoking the endpoint raises any exception.
        """
        # Nothing to score: skip the network round trip entirely.
        if not text_pairs:
            return []

        _endpoint_kwargs = self.endpoint_kwargs or {}

        body = self.content_handler.transform_input(text_pairs)
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # send request
        try:
            response = self.client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                Body=body,
                ContentType=content_type,
                Accept=accepts,
                **_endpoint_kwargs,
            )
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        return self.content_handler.transform_output(response["Body"])