Source code for langchain_community.llms.xinference

from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

if TYPE_CHECKING:
    from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
    from xinference.model.llm.core import LlamaCppGenerateConfig


class Xinference(LLM):
    """`Xinference` large-scale model inference service.

    To use, you should have the xinference library installed:

    .. code-block:: bash

       pip install "xinference[all]"

    Check out: https://github.com/xorbitsai/inference
    To run, you need to start a Xinference supervisor on one server and Xinference
    workers on the other servers.

    Example:
        To start a local instance of Xinference, run

        .. code-block:: bash

           $ xinference

        You can also deploy Xinference in a distributed cluster. Here are the steps:

        Starting the supervisor:

        .. code-block:: bash

           $ xinference-supervisor

        Starting the workers:

        .. code-block:: bash

           $ xinference-worker

    Then, launch a model using the command line interface (CLI).

    Example:

    .. code-block:: bash

       $ xinference launch -n orca -s 3 -q q4_0

    It will return a model UID. Then, you can use Xinference with LangChain.

    Example:

    .. code-block:: python

        from langchain_community.llms import Xinference

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid={model_uid},  # replace {model_uid} with the model UID returned from launching the model
        )

        llm.invoke(
            "Q: where can we visit in the capital of France? A:",
            generate_config={"max_tokens": 1024, "stream": True},
        )

    To view all the supported builtin models, run:

    .. code-block:: bash

       $ xinference list --all

    """  # noqa: E501

    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""
    model_kwargs: Dict[str, Any]
    """Keyword arguments to be passed to xinference.LLM"""

    def __init__(
        self,
        server_url: Optional[str] = None,
        model_uid: Optional[str] = None,
        **model_kwargs: Any,
    ):
        try:
            from xinference.client import RESTfulClient
        except ImportError as e:
            raise ImportError(
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference`."
            ) from e

        model_kwargs = model_kwargs or {}

        super().__init__(
            **{  # type: ignore[arg-type]
                "server_url": server_url,
                "model_uid": model_uid,
                "model_kwargs": model_kwargs,
            }
        )

        if self.server_url is None:
            raise ValueError("Please provide server URL")

        if self.model_uid is None:
            raise ValueError("Please provide the model UID")

        self.client = RESTfulClient(server_url)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinference"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"server_url": self.server_url},
            **{"model_uid": self.model_uid},
            **{"model_kwargs": self.model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Returns:
            The string generated by the model.
        """
        model = self.client.get_model(self.model_uid)

        generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

        generate_config = {**self.model_kwargs, **generate_config}

        if stop:
            generate_config["stop"] = stop

        if generate_config and generate_config.get("stream"):
            combined_text_output = ""
            for token in self._stream_generate(
                model=model,
                prompt=prompt,
                run_manager=run_manager,
                generate_config=generate_config,
            ):
                combined_text_output += token
            return combined_text_output
        else:
            completion = model.generate(prompt=prompt, generate_config=generate_config)
            return completion["choices"][0]["text"]

    def _stream_generate(
        self,
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
        prompt: str,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        generate_config: Optional["LlamaCppGenerateConfig"] = None,
    ) -> Generator[str, None, None]:
        """
        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        streaming_response = model.generate(
            prompt=prompt, generate_config=generate_config
        )
        for chunk in streaming_response:
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token