from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
if TYPE_CHECKING:
from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
from xinference.model.llm.core import LlamaCppGenerateConfig
class Xinference(LLM):
    """`Xinference` large-scale model inference service.

    To use, you should have the xinference library installed:

    .. code-block:: bash

       pip install "xinference[all]"

    Check out: https://github.com/xorbitsai/inference

    To run, you need to start a Xinference supervisor on one server and
    Xinference workers on the other servers.

    Example:
        To start a local instance of Xinference, run

        .. code-block:: bash

           $ xinference

        You can also deploy Xinference in a distributed cluster. Here are the steps:

        Starting the supervisor:

        .. code-block:: bash

           $ xinference-supervisor

        Starting the workers:

        .. code-block:: bash

           $ xinference-worker

    Then, launch a model using the command line interface (CLI).

    Example:

    .. code-block:: bash

       $ xinference launch -n orca -s 3 -q q4_0

    It will return a model UID. Then, you can use Xinference with LangChain.

    Example:

    .. code-block:: python

        from langchain_community.llms import Xinference

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid = {model_uid}  # replace model_uid with the UID returned from launching the model
        )

        llm.invoke(
            prompt="Q: where can we visit in the capital of France? A:",
            generate_config={"max_tokens": 1024, "stream": True},
        )

    To view all the supported builtin models, run:

    .. code-block:: bash

       $ xinference list --all
    """  # noqa: E501

    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""
    model_kwargs: Dict[str, Any]
    """Keyword arguments to be passed to xinference.LLM"""

    def __init__(
        self,
        server_url: Optional[str] = None,
        model_uid: Optional[str] = None,
        **model_kwargs: Any,
    ):
        # Import lazily so the class can be declared without xinference installed.
        try:
            from xinference.client import RESTfulClient
        except ImportError as e:
            raise ImportError(
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference`."
            ) from e

        model_kwargs = model_kwargs or {}

        # Route everything through the pydantic-aware base initializer so the
        # declared fields are validated and assigned.
        super().__init__(
            **{  # type: ignore[arg-type]
                "server_url": server_url,
                "model_uid": model_uid,
                "model_kwargs": model_kwargs,
            }
        )

        if self.server_url is None:
            raise ValueError("Please provide server URL")

        if self.model_uid is None:
            raise ValueError("Please provide the model UID")

        self.client = RESTfulClient(server_url)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinference"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"server_url": self.server_url},
            **{"model_uid": self.model_uid},
            **{"model_kwargs": self.model_kwargs},
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the xinference model and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: Optional list of stop words to use when generating.
            generate_config: Optional dictionary for the configuration used for
                generation (passed via ``kwargs``).

        Returns:
            The generated string by the model.
        """
        model = self.client.get_model(self.model_uid)

        generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})

        # Per-call config overrides the constructor-level model_kwargs.
        generate_config = {**self.model_kwargs, **generate_config}

        if stop:
            generate_config["stop"] = stop

        if generate_config and generate_config.get("stream"):
            combined_text_output = ""
            for token in self._stream_generate(
                model=model,
                prompt=prompt,
                run_manager=run_manager,
                generate_config=generate_config,
            ):
                combined_text_output += token
            return combined_text_output

        else:
            completion = model.generate(prompt=prompt, generate_config=generate_config)
            return completion["choices"][0]["text"]

    def _stream_generate(
        self,
        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
        prompt: str,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        generate_config: Optional["LlamaCppGenerateConfig"] = None,
    ) -> Generator[str, None, None]:
        """Stream tokens from the xinference model.

        Args:
            prompt: The prompt to use for generation.
            model: The model used for generation.
            generate_config: Optional dictionary for the configuration used for
                generation.

        Yields:
            A string token.
        """
        streaming_response = model.generate(
            prompt=prompt, generate_config=generate_config
        )
        for chunk in streaming_response:
            # Defensively skip malformed chunks; only well-formed dict chunks
            # with a non-empty "choices" list produce a token.
            if isinstance(chunk, dict):
                choices = chunk.get("choices", [])
                if choices:
                    choice = choices[0]
                    if isinstance(choice, dict):
                        token = choice.get("text", "")
                        log_probs = choice.get("logprobs")
                        if run_manager:
                            # Surface each token to the callback manager as it arrives.
                            run_manager.on_llm_new_token(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token