Source code for langchain_community.embeddings.xinference

"""Xinference嵌入模型的封装器。"""

from typing import Any, List, Optional

from langchain_core.embeddings import Embeddings


[docs]class XinferenceEmbeddings(Embeddings): """运行 Xinference 嵌入模型。 要使用,您应该已安装 xinference 库: .. code-block:: bash pip install xinference 如果您只是使用 Xinference 提供的服务,可以使用 xinference_client 包: .. code-block:: bash pip install xinference_client 查看:https://github.com/xorbitsai/inference 要运行,您需要在一个服务器上启动 Xinference 监督员,并在其他服务器上启动 Xinference 工作者。 示例: 要启动 Xinference 的本地实例,请运行 .. code-block:: bash $ xinference 您还可以在分布式集群中部署 Xinference。以下是步骤: 启动监督员: .. code-block:: bash $ xinference-supervisor 如果您只是使用 Xinference 提供的服务,可以使用 xinference_client 包: .. code-block:: bash pip install xinference_client 启动工作者: .. code-block:: bash $ xinference-worker 然后,使用命令行界面(CLI)启动模型。 示例: .. code-block:: bash $ xinference launch -n orca -s 3 -q q4_0 它将返回一个模型 UID。然后您可以使用 LangChain 中的 Xinference 嵌入。 示例: .. code-block:: python from langchain_community.embeddings import XinferenceEmbeddings xinference = XinferenceEmbeddings( server_url="http://0.0.0.0:9997", model_uid = {model_uid} # 用从启动模型返回的模型 UID 替换 model_uid )""" # noqa: E501 client: Any server_url: Optional[str] """xinference服务器的URL""" model_uid: Optional[str] """启动模型的UID"""
[docs] def __init__( self, server_url: Optional[str] = None, model_uid: Optional[str] = None ): try: from xinference.client import RESTfulClient except ImportError: try: from xinference_client import RESTfulClient except ImportError as e: raise ImportError( "Could not import RESTfulClient from xinference. Please install it" " with `pip install xinference` or `pip install xinference_client`." ) from e super().__init__() if server_url is None: raise ValueError("Please provide server URL") if model_uid is None: raise ValueError("Please provide the model UID") self.server_url = server_url self.model_uid = model_uid self.client = RESTfulClient(server_url)
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """使用Xinference嵌入文档列表。 参数: texts:要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ model = self.client.get_model(self.model_uid) embeddings = [ model.create_embedding(text)["data"][0]["embedding"] for text in texts ] return [list(map(float, e)) for e in embeddings]
[docs] def embed_query(self, text: str) -> List[float]: """使用Xinference嵌入文档查询。 参数: text:要嵌入的文本。 返回: 文本的嵌入。 """ model = self.client.get_model(self.model_uid) embedding_res = model.create_embedding(text) embedding = embedding_res["data"][0]["embedding"] return list(map(float, embedding))