Source code for langchain_community.embeddings.titan_takeoff

from enum import Enum
from typing import Any, List, Optional, Set, Union

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel


[docs]class TakeoffEmbeddingException(Exception): """用于与Takeoff Embedding类进行交互的自定义异常。"""
[docs]class MissingConsumerGroup(TakeoffEmbeddingException): """在初始化TitanTakeoffEmbed或在嵌入请求中未提供消费者组时引发的异常。"""
[docs]class Device(str, Enum): """用于推理的设备,cuda或cpu。""" cuda = "cuda" cpu = "cpu"
[docs]class ReaderConfig(BaseModel): """部署在Takeoff中的读取器配置。""" class Config: protected_namespaces = () model_name: str """要使用的模型名称""" device: Device = Device.cuda """用于推理的设备,cuda或cpu""" consumer_group: str = "primary" """将读者放入的消费者组"""
[docs]class TitanTakeoffEmbed(Embeddings): """与Takeoff推理API接口进行模型嵌入。 用于发送嵌入请求并部署嵌入阅读器与Takeoff。 示例: 这是一个部署嵌入模型并发送请求的示例。 .. code-block:: python # 从社区包中导入TitanTakeoffEmbed类 import time from langchain_community.embeddings import TitanTakeoffEmbed # 指定要部署的嵌入阅读器 reader_1 = { "model_name": "avsolatorio/GIST-large-Embedding-v0", "device": "cpu", "consumer_group": "embed" } # 对于传递到models参数的每个阅读器,Takeoff将根据您提供的规格启动一个阅读器。 # 如果您不指定该参数,则不会启动任何模型,并且它会假定您已经单独完成了这一步。 embed = TitanTakeoffEmbed(models=[reader_1]) # 等待阅读器部署完成,所需时间取决于模型大小和您的互联网速度 time.sleep(60) # 返回嵌入查询,即发送到`embed`消费者组的List[float],我们刚刚启动了嵌入阅读器 print(embed.embed_query( "我在哪里可以看到足球?", consumer_group="embed" )) # 返回嵌入列表,即发送到`embed`消费者组的List[List[float]],我们刚刚启动了嵌入阅读器 print(embed.embed_document( ["文档1", "文档2"], consumer_group="embed" ))""" base_url: str = "http://localhost" """Titan Takeoff (Pro) 服务器的基本URL。默认值为 "http://localhost"。""" port: int = 3000 """Titan Takeoff(Pro)服务器的端口。默认值为3000。""" mgmt_port: int = 3001 """Titan Takeoff(Pro)服务器的管理端口。默认值为3001。""" client: Any = None """Takeoff客户端Python SDK用于与Takeoff API进行交互""" embed_consumer_groups: Set[str] = set() """Takeoff中包含嵌入模型的消费者组"""
[docs] def __init__( self, base_url: str = "http://localhost", port: int = 3000, mgmt_port: int = 3001, models: List[ReaderConfig] = [], ): """初始化Titan Takeoff嵌入式包装器。 参数: base_url (str, 可选): Takeoff推理服务器监听的基本URL。默认为"http://localhost"。 port (int, 可选): Takeoff推理API监听的端口号。默认为3000。 mgmt_port (int, 可选): Takeoff管理API监听的端口号。默认为3001。 models (List[ReaderConfig], 可选): 您想要在其中启动的任何读取器。默认为[]。 抛出: ImportError: 如果您尚未安装takeoff-client,则会出现ImportError。要解决此问题,请运行 `pip install 'takeoff-client==0.4.0'`。 """ self.base_url = base_url self.port = port self.mgmt_port = mgmt_port try: from takeoff_client import TakeoffClient except ImportError: raise ImportError( "takeoff-client is required for TitanTakeoff. " "Please install it with `pip install 'takeoff-client==0.4.0'`." ) self.client = TakeoffClient( self.base_url, port=self.port, mgmt_port=self.mgmt_port ) for model in models: self.client.create_reader(model) if isinstance(model, dict): self.embed_consumer_groups.add(model.get("consumer_group")) else: self.embed_consumer_groups.add(model.consumer_group) super(TitanTakeoffEmbed, self).__init__()
def _embed( self, input: Union[List[str], str], consumer_group: Optional[str] ) -> dict: """嵌入文本。 参数: input (List[str]): 要嵌入的提示/文档或提示/文档列表 consumer_group (Optional[str]): 发送嵌入请求的消费者组。如果未指定,并且在初始化期间只指定了一个消费者组,则将使用该消费者组。如果在初始化期间指定了多个消费者组,则必须指定要使用的消费者组。 引发: MissingConsumerGroup: 无法从初始化中推断出消费者组,必须在请求中指定。 返回: Dict[str, Any]: 查询结果,{"result": List[List[float]]} 或 {"result": List[float]} """ if not consumer_group: if len(self.embed_consumer_groups) == 1: consumer_group = list(self.embed_consumer_groups)[0] elif len(self.embed_consumer_groups) > 1: raise MissingConsumerGroup( "TakeoffEmbedding was initialized with multiple embedding reader" "groups, you must specify which one to use." ) else: raise MissingConsumerGroup( "You must specify what consumer group you want to send embedding" "response to as TitanTakeoffEmbed was not initialized with an " "embedding reader." ) return self.client.embed(input, consumer_group)
[docs] def embed_documents( self, texts: List[str], consumer_group: Optional[str] = None ) -> List[List[float]]: """嵌入文档。 参数: texts (List[str]): 需要嵌入的提示/文档列表 consumer_group (Optional[str], optional): 发送请求的消费者组,包含嵌入模型。默认为None。 返回: List[List[float]]: 嵌入列表 """ return self._embed(texts, consumer_group)["result"]
[docs] def embed_query( self, text: str, consumer_group: Optional[str] = None ) -> List[float]: """嵌入查询。 参数: text(str):要嵌入的提示/文档 consumer_group(Optional[str],可选):发送请求的消费者组,其中包含嵌入模型。默认为None。 返回: List[float]:嵌入 """ return self._embed(text, consumer_group)["result"]