Source code for langchain_community.embeddings.titan_takeoff
from enum import Enum
from typing import Any, List, Optional, Set, Union
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
[docs]class TakeoffEmbeddingException(Exception):
"""用于与Takeoff Embedding类进行交互的自定义异常。"""
[docs]class MissingConsumerGroup(TakeoffEmbeddingException):
"""在初始化TitanTakeoffEmbed或在嵌入请求中未提供消费者组时引发的异常。"""
[docs]class Device(str, Enum):
"""用于推理的设备,cuda或cpu。"""
cuda = "cuda"
cpu = "cpu"
[docs]class ReaderConfig(BaseModel):
"""部署在Takeoff中的读取器配置。"""
class Config:
protected_namespaces = ()
model_name: str
"""要使用的模型名称"""
device: Device = Device.cuda
"""用于推理的设备,cuda或cpu"""
consumer_group: str = "primary"
"""将读者放入的消费者组"""
[docs]class TitanTakeoffEmbed(Embeddings):
"""与Takeoff推理API接口进行模型嵌入。
用于发送嵌入请求并部署嵌入阅读器与Takeoff。
示例:
这是一个部署嵌入模型并发送请求的示例。
.. code-block:: python
# 从社区包中导入TitanTakeoffEmbed类
import time
from langchain_community.embeddings import TitanTakeoffEmbed
# 指定要部署的嵌入阅读器
reader_1 = {
"model_name": "avsolatorio/GIST-large-Embedding-v0",
"device": "cpu",
"consumer_group": "embed"
}
# 对于传递到models参数的每个阅读器,Takeoff将根据您提供的规格启动一个阅读器。
# 如果您不指定该参数,则不会启动任何模型,并且它会假定您已经单独完成了这一步。
embed = TitanTakeoffEmbed(models=[reader_1])
# 等待阅读器部署完成,所需时间取决于模型大小和您的互联网速度
time.sleep(60)
# 返回嵌入查询,即发送到`embed`消费者组的List[float],我们刚刚启动了嵌入阅读器
print(embed.embed_query(
"我在哪里可以看到足球?", consumer_group="embed"
))
# 返回嵌入列表,即发送到`embed`消费者组的List[List[float]],我们刚刚启动了嵌入阅读器
print(embed.embed_document(
["文档1", "文档2"],
consumer_group="embed"
))"""
base_url: str = "http://localhost"
"""Titan Takeoff (Pro) 服务器的基本URL。默认值为 "http://localhost"。"""
port: int = 3000
"""Titan Takeoff(Pro)服务器的端口。默认值为3000。"""
mgmt_port: int = 3001
"""Titan Takeoff(Pro)服务器的管理端口。默认值为3001。"""
client: Any = None
"""Takeoff客户端Python SDK用于与Takeoff API进行交互"""
embed_consumer_groups: Set[str] = set()
"""Takeoff中包含嵌入模型的消费者组"""
[docs] def __init__(
self,
base_url: str = "http://localhost",
port: int = 3000,
mgmt_port: int = 3001,
models: List[ReaderConfig] = [],
):
"""初始化Titan Takeoff嵌入式包装器。
参数:
base_url (str, 可选): Takeoff推理服务器监听的基本URL。默认为"http://localhost"。
port (int, 可选): Takeoff推理API监听的端口号。默认为3000。
mgmt_port (int, 可选): Takeoff管理API监听的端口号。默认为3001。
models (List[ReaderConfig], 可选): 您想要在其中启动的任何读取器。默认为[]。
抛出:
ImportError: 如果您尚未安装takeoff-client,则会出现ImportError。要解决此问题,请运行 `pip install 'takeoff-client==0.4.0'`。
"""
self.base_url = base_url
self.port = port
self.mgmt_port = mgmt_port
try:
from takeoff_client import TakeoffClient
except ImportError:
raise ImportError(
"takeoff-client is required for TitanTakeoff. "
"Please install it with `pip install 'takeoff-client==0.4.0'`."
)
self.client = TakeoffClient(
self.base_url, port=self.port, mgmt_port=self.mgmt_port
)
for model in models:
self.client.create_reader(model)
if isinstance(model, dict):
self.embed_consumer_groups.add(model.get("consumer_group"))
else:
self.embed_consumer_groups.add(model.consumer_group)
super(TitanTakeoffEmbed, self).__init__()
def _embed(
self, input: Union[List[str], str], consumer_group: Optional[str]
) -> dict:
"""嵌入文本。
参数:
input (List[str]): 要嵌入的提示/文档或提示/文档列表
consumer_group (Optional[str]): 发送嵌入请求的消费者组。如果未指定,并且在初始化期间只指定了一个消费者组,则将使用该消费者组。如果在初始化期间指定了多个消费者组,则必须指定要使用的消费者组。
引发:
MissingConsumerGroup: 无法从初始化中推断出消费者组,必须在请求中指定。
返回:
Dict[str, Any]: 查询结果,{"result": List[List[float]]} 或 {"result": List[float]}
"""
if not consumer_group:
if len(self.embed_consumer_groups) == 1:
consumer_group = list(self.embed_consumer_groups)[0]
elif len(self.embed_consumer_groups) > 1:
raise MissingConsumerGroup(
"TakeoffEmbedding was initialized with multiple embedding reader"
"groups, you must specify which one to use."
)
else:
raise MissingConsumerGroup(
"You must specify what consumer group you want to send embedding"
"response to as TitanTakeoffEmbed was not initialized with an "
"embedding reader."
)
return self.client.embed(input, consumer_group)
[docs] def embed_documents(
self, texts: List[str], consumer_group: Optional[str] = None
) -> List[List[float]]:
"""嵌入文档。
参数:
texts (List[str]): 需要嵌入的提示/文档列表
consumer_group (Optional[str], optional): 发送请求的消费者组,包含嵌入模型。默认为None。
返回:
List[List[float]]: 嵌入列表
"""
return self._embed(texts, consumer_group)["result"]
[docs] def embed_query(
self, text: str, consumer_group: Optional[str] = None
) -> List[float]:
"""嵌入查询。
参数:
text(str):要嵌入的提示/文档
consumer_group(Optional[str],可选):发送请求的消费者组,其中包含嵌入模型。默认为None。
返回:
List[float]:嵌入
"""
return self._embed(text, consumer_group)["result"]