Source code for langchain_community.embeddings.voyageai

from __future__ import annotations

import json
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
)

# Module-level logger; _create_retry_decorator uses it to emit a WARNING
# before each backoff sleep between retried Voyage API calls.
logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: VoyageEmbeddings) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for Voyage API calls.

    The resulting policy:
      * stops after ``embeddings.max_retries`` attempts in total,
      * sleeps with exponential backoff clamped to the 4-10 second range,
      * logs a WARNING through the module logger before every sleep,
      * re-raises the last exception once the attempts are exhausted.

    Args:
        embeddings: Instance whose ``max_retries`` bounds the attempt count.

    Returns:
        A decorator to apply to the function that performs the request.
    """
    # Exponential backoff (multiplier 1) clamped to [4, 10] seconds: early
    # waits are raised to 4s, later ones are capped at 10s.
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    return retry(
        stop=stop_after_attempt(embeddings.max_retries),
        wait=backoff,
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )


def _check_response(response: dict) -> dict:
    if "data" not in response:
        raise RuntimeError(f"Voyage API Error. Message: {json.dumps(response)}")
    return response


def embed_with_retry(embeddings: VoyageEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call.

    Args:
        embeddings: Instance whose ``max_retries`` configures the retry policy.
        **kwargs: Keyword arguments forwarded unchanged to ``requests.post``
            (``url``, ``headers``, ``json``, ``timeout`` ...).

    Returns:
        The decoded JSON body of the response, once ``_check_response``
        accepts it (i.e. it contains a ``"data"`` key).
    """

    @_create_retry_decorator(embeddings)
    def _post_and_validate(**request_kwargs: Any) -> Any:
        # Every attempt re-issues the POST. No ``retry=`` predicate is given to
        # tenacity, so its default applies and an exception raised here
        # (network error, JSON decoding failure, or the RuntimeError from
        # _check_response) leads to another attempt.
        return _check_response(requests.post(**request_kwargs).json())

    return _post_and_validate(**kwargs)
@deprecated(
    since="0.0.29",
    removal="0.3",
    alternative_import="langchain_voyageai.VoyageAIEmbeddings",
)
class VoyageEmbeddings(BaseModel, Embeddings):
    """Voyage embedding models.

    To use, you should have the environment variable ``VOYAGE_API_KEY`` set
    with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import VoyageEmbeddings

            voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2")
            text = "This is a test query."
            query_result = voyage.embed_query(text)
    """

    model: str
    voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
    voyage_api_key: Optional[SecretStr] = None
    batch_size: int
    """Maximum number of texts to embed in each API request. Must be positive."""
    max_retries: int = 6
    """Maximum number of attempts per request.

    Passed to tenacity's ``stop_after_attempt``, so it counts the initial
    attempt as well as the retries."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout in seconds for the API request (forwarded to ``requests.post``)."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding. Requires tqdm if True."""
    truncation: bool = True
    """Whether to truncate the input texts to fit within the context length.

    If True, over-length input texts will be truncated to fit within the context
    length before being vectorized by the embedding model. If False, an error
    will be raised if any given text exceeds the context length."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and fill in defaults for ``model``/``batch_size``.

        The API key is taken from ``voyage_api_key`` or, failing that, the
        ``VOYAGE_API_KEY`` environment variable, and wrapped in a SecretStr.
        """
        values["voyage_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
        )

        if "model" not in values:
            values["model"] = "voyage-01"
            logger.warning(
                "model will become a required arg for VoyageAIEmbeddings, "
                "we recommend to specify it when using this class. "
                "Currently the default is set to voyage-01."
            )

        if "batch_size" not in values:
            # "model" is always present at this point (defaulted above).
            # 72 for voyage-2 / voyage-02, 7 otherwise — presumably per-model
            # API batch limits; TODO confirm against the Voyage docs.
            values["batch_size"] = (
                72 if values["model"] in ["voyage-2", "voyage-02"] else 7
            )
        return values

    def _invocation_params(
        self, input: List[str], input_type: Optional[str] = None
    ) -> Dict:
        """Build the ``requests.post`` keyword arguments for one batch."""
        api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
        params: Dict = {
            "url": self.voyage_api_base,
            "headers": {"Authorization": f"Bearer {api_key}"},
            "json": {
                "model": self.model,
                "input": input,
                "input_type": input_type,
                "truncation": self.truncation,
            },
            "timeout": self.request_timeout,
        }
        return params

    def _get_embeddings(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        input_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed ``texts`` with one API request per batch.

        Args:
            texts: Texts to embed.
            batch_size: Texts per request; defaults to ``self.batch_size``.
            input_type: None, "query" or "document".

        Returns:
            One embedding per text, in the order the API returns them batch
            after batch.

        Raises:
            ValueError: If ``batch_size`` is not positive or ``input_type`` is
                not one of the accepted values.
            ImportError: If ``show_progress_bar`` is set but tqdm is missing.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # ``batch_size`` is the ``range`` step below: 0 makes range() raise a
        # cryptic error and a negative value yields no batches at all, which
        # would silently return no embeddings. Reject both explicitly.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}."
            )
        # Validate before any progress bar is created.
        if input_type and input_type not in ["query", "document"]:
            raise ValueError(
                f"input_type {input_type} is invalid. Options: None, 'query', "
                "'document'."
            )

        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(range(0, len(texts), batch_size))
        else:
            _iter = range(0, len(texts), batch_size)

        embeddings: List[List[float]] = []
        for i in _iter:
            response = embed_with_retry(
                self,
                **self._invocation_params(
                    input=texts[i : i + batch_size], input_type=input_type
                ),
            )
            embeddings.extend(r["embedding"] for r in response["data"])
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Voyage Embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._get_embeddings(
            texts, batch_size=self.batch_size, input_type="document"
        )

    def embed_query(self, text: str) -> List[float]:
        """Call out to Voyage Embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._get_embeddings(
            [text], batch_size=self.batch_size, input_type="query"
        )[0]

    def embed_general_texts(
        self, texts: List[str], *, input_type: Optional[str] = None
    ) -> List[List[float]]:
        """Call out to Voyage Embedding endpoint for embedding general text.

        Args:
            texts: The list of texts to embed.
            input_type: Type of the input text. Default to None, meaning the
                type is unspecified. Other options: query, document.

        Returns:
            Embedding for the text.
        """
        return self._get_embeddings(
            texts, batch_size=self.batch_size, input_type=input_type
        )