# Source code for langchain_community.embeddings.yandex

"""Wrapper around YandexGPT embedding models."""
from __future__ import annotations

import logging
import time
from typing import Any, Callable, Dict, List, Sequence

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

# Module-level logger. The tenacity retry decorator built in
# `_create_retry_decorator` logs a WARNING through it before each back-off sleep.
logger = logging.getLogger(__name__)


class YandexGPTEmbeddings(BaseModel, Embeddings):
    """YandexGPT Embeddings models.

    To use, you should have the ``yandexcloud`` python package installed.

    There are two authentication options for the service account
    with the ``ai.languageModels.user`` role:
        - You can specify the token in a constructor parameter `iam_token`
          or in an environment variable `YC_IAM_TOKEN`.
        - You can specify the key in a constructor parameter `api_key`
          or in an environment variable `YC_API_KEY`.

    To use the default model specify the folder ID in a parameter `folder_id`
    or in an environment variable `YC_FOLDER_ID`.

    Example:
        .. code-block:: python

            from langchain_community.embeddings.yandex import YandexGPTEmbeddings
            embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", folder_id=<folder-id>)
    """  # noqa: E501

    iam_token: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud IAM token for a service account
    with the `ai.languageModels.user` role"""
    api_key: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud API key for a service account
    with the `ai.languageModels.user` role"""
    model_uri: str = Field(default="", alias="query_model_uri")
    """Query model URI to use."""
    doc_model_uri: str = ""
    """Document model URI to use."""
    folder_id: str = ""
    """Yandex Cloud folder ID"""
    doc_model_name: str = "text-search-doc"
    """Document model name to use."""
    model_name: str = Field(default="text-search-query", alias="query_model_name")
    """Query model name to use."""
    model_version: str = "latest"
    """Model version to use."""
    url: str = "llm.api.cloud.yandex.net:443"
    """The URL of the API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    sleep_interval: float = 0.0
    """Delay between API requests"""
    disable_request_logging: bool = False
    """YandexGPT API logs all request data by default.
    If you provide personal data or confidential information, disable logging."""
    # gRPC call metadata (auth header plus optional extra headers); populated
    # by `validate_environment` and passed to every embedding request.
    _grpc_metadata: Sequence

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials and model URIs and build the gRPC metadata.

        `iam_token`, `api_key` and `folder_id` fall back to the
        `YC_IAM_TOKEN`, `YC_API_KEY` and `YC_FOLDER_ID` environment variables.
        Missing model URIs are built from `folder_id`, the model names and
        `model_version`.

        Raises:
            ValueError: if neither an IAM token nor an API key is available,
                or if a model URI is missing and `folder_id` is empty.
        """
        iam_token = convert_to_secret_str(
            get_from_dict_or_env(values, "iam_token", "YC_IAM_TOKEN", "")
        )
        values["iam_token"] = iam_token
        api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
        )
        values["api_key"] = api_key
        folder_id = get_from_dict_or_env(values, "folder_id", "YC_FOLDER_ID", "")
        values["folder_id"] = folder_id
        if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "":
            raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")

        # Always build a list: extra headers (folder ID, logging opt-out) are
        # appended below. A tuple here made `disable_request_logging=True`
        # crash with AttributeError under API-key authentication.
        grpc_metadata = []
        if iam_token.get_secret_value():
            grpc_metadata.append(
                ("authorization", f"Bearer {iam_token.get_secret_value()}")
            )
            if folder_id:
                grpc_metadata.append(("x-folder-id", folder_id))
        else:
            grpc_metadata.append(
                ("authorization", f"Api-Key {api_key.get_secret_value()}")
            )

        if not values.get("doc_model_uri"):
            if values["folder_id"] == "":
                raise ValueError("'doc_model_uri' or 'folder_id' must be provided.")
            values[
                "doc_model_uri"
            ] = f"emb://{values['folder_id']}/{values['doc_model_name']}/{values['model_version']}"  # noqa: E501
        if not values.get("model_uri"):
            if values["folder_id"] == "":
                raise ValueError("'model_uri' or 'folder_id' must be provided.")
            values[
                "model_uri"
            ] = f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}"  # noqa: E501

        if values["disable_request_logging"]:
            grpc_metadata.append(("x-data-logging-enabled", "false"))
        values["_grpc_metadata"] = grpc_metadata
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using the YandexGPT document embeddings model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return _embed_with_retry(self, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using the YandexGPT query embeddings model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return _embed_with_retry(self, texts=[text], embed_query=True)[0]
def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]:
    """Build a tenacity retry decorator for gRPC calls made on behalf of `llm`.

    Retries only on `grpc.RpcError`, for at most `llm.max_retries` attempts,
    with exponential back-off bounded to 1..60 seconds. A WARNING is logged
    before each sleep, and the final error is re-raised.
    """
    from grpc import RpcError

    min_seconds = 1
    max_seconds = 60
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_if_exception_type(RpcError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
    """Run `_make_request(llm, **kwargs)` under tenacity retries."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    def _completion_with_retry(**_kwargs: Any) -> Any:
        return _make_request(llm, **_kwargs)

    return _completion_with_retry(**kwargs)


def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs):  # type: ignore[no-untyped-def]
    """Embed each text with one `TextEmbedding` gRPC call.

    Args:
        self: Configured embeddings model (URL, model URIs, metadata, delay).
        texts: Texts to embed, one request per text.
        **kwargs: If `embed_query` is truthy the query model URI is used,
            otherwise the document model URI.

    Returns:
        A list of embeddings (lists of floats) in the same order as `texts`.

    Raises:
        ImportError: if the YandexCloud SDK / grpc is not installed.
    """
    try:
        import grpc

        try:
            from yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2 import (  # noqa: E501
                TextEmbeddingRequest,
            )
            from yandex.cloud.ai.foundation_models.v1.embedding.embedding_service_pb2_grpc import (  # noqa: E501
                EmbeddingsServiceStub,
            )
        except ModuleNotFoundError:
            # Older SDK releases expose the same messages from a different module.
            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
                TextEmbeddingRequest,
            )
            from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
                EmbeddingsServiceStub,
            )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK with `pip install yandexcloud` "
            "or upgrade it to recent version."
        ) from e

    # Use the query model if embed_query is True
    model_uri = self.model_uri if kwargs.get("embed_query") else self.doc_model_uri

    result: List[List[float]] = []
    channel_credentials = grpc.ssl_channel_credentials()
    # Close the channel on exit, including when an RpcError escapes to the
    # retry wrapper, so repeated attempts do not leak open channels.
    with grpc.secure_channel(self.url, channel_credentials) as channel:
        # The stub is loop-invariant; create it once per channel.
        stub = EmbeddingsServiceStub(channel)
        for text in texts:
            request = TextEmbeddingRequest(model_uri=model_uri, text=text)
            res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
            result.append(list(res.embedding))
            time.sleep(self.sleep_interval)
    return result