Source code for langchain_community.llms.self_hosted_hugging_face

import importlib.util
import logging
from typing import Any, Callable, List, Mapping, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.pydantic_v1 import Extra

from langchain_community.llms.self_hosted import SelfHostedPipeline
from langchain_community.llms.utils import enforce_stop_tokens

DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")

logger = logging.getLogger(__name__)


def _generate_text(
    pipeline: Any,
    prompt: str,
    *args: Any,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> str:
    """推断函数,用于发送到远程硬件。

接受一个Hugging Face pipeline(或更可能是指向集群对象存储中这样一个pipeline的键),并返回生成的文本。
"""
    response = pipeline(prompt, *args, **kwargs)
    if pipeline.task == "text-generation":
        # Text generation return includes the starter text.
        text = response[0]["generated_text"][len(prompt) :]
    elif pipeline.task == "text2text-generation":
        text = response[0]["generated_text"]
    elif pipeline.task == "summarization":
        text = response[0]["summary_text"]
    else:
        raise ValueError(
            f"Got invalid task {pipeline.task}, "
            f"currently only {VALID_TASKS} are supported"
        )
    if stop is not None:
        text = enforce_stop_tokens(text, stop)
    return text
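
# Illustrative local sanity check for `_generate_text` (a sketch, not part of the
# module): it assumes `transformers` and a small model such as gpt2 are available
# locally, and shows how the starter prompt is stripped from "text-generation"
# output before stop tokens are enforced.
#
#     from transformers import pipeline as hf_pipeline
#
#     pipe = hf_pipeline("text-generation", model="gpt2")
#     text = _generate_text(pipe, "Once upon a time", max_new_tokens=20, stop=["\n"])
#     # `text` holds only the newly generated tokens, truncated at the first "\n".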


def _load_transformer(
    model_id: str = DEFAULT_MODEL_ID,
    task: str = DEFAULT_TASK,
    device: int = 0,
    model_kwargs: Optional[dict] = None,
) -> Any:
    """将发送到远程硬件的推理函数。

接受一个huggingface模型ID,并返回一个用于该任务的流水线。
"""
    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
    from transformers import pipeline as hf_pipeline

    _model_kwargs = model_kwargs or {}
    tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)

    try:
        if task == "text-generation":
            model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
        elif task in ("text2text-generation", "summarization"):
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
        else:
            raise ValueError(
                f"Got invalid task {task}, "
                f"currently only {VALID_TASKS} are supported"
            )
    except ImportError as e:
        raise ImportError(
            f"Could not load the {task} model due to missing dependencies."
        ) from e

    if importlib.util.find_spec("torch") is not None:
        import torch

        cuda_device_count = torch.cuda.device_count()
        if device < -1 or (device >= cuda_device_count):
            raise ValueError(
                f"Got device=={device}, "
                f"device is required to be within [-1, {cuda_device_count})"
            )
        if device < 0 and cuda_device_count > 0:
            logger.warning(
                "Device has %d GPUs available. "
                "Provide device={deviceId} to `from_model_id` to use available"
                "GPUs for execution. deviceId is -1 for CPU and "
                "can be a positive integer associated with CUDA device id.",
                cuda_device_count,
            )

    pipeline = hf_pipeline(
        task=task,
        model=model,
        tokenizer=tokenizer,
        device=device,
        model_kwargs=_model_kwargs,
    )
    if pipeline.task not in VALID_TASKS:
        raise ValueError(
            f"Got invalid task {pipeline.task}, "
            f"currently only {VALID_TASKS} are supported"
        )
    return pipeline
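
# Illustrative sketch of how `_load_transformer` is typically exercised (in normal
# use it runs on the remote hardware via `model_load_fn`); device=-1 keeps the
# model on CPU, which is the assumption made here for a quick check.
#
#     pipe = _load_transformer(model_id="gpt2", task="text-generation", device=-1)
#     pipe("Hello", max_new_tokens=5)  # -> [{"generated_text": "..."}]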


class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
    """HuggingFace Pipeline API to run on self-hosted remote hardware.

    Supported hardware includes auto-launched instances on AWS, GCP, Azure, and
    Lambda, as well as servers specified by IP address and SSH credentials
    (such as on-prem, or another cloud like Paperspace, Coreweave, etc.).

    To use, you should have the ``runhouse`` python package installed.

    Only `text-generation`, `text2text-generation` and `summarization` are
    supported for now.

    Example using `from_model_id`:

    ```python
    from langchain_community.llms import SelfHostedHuggingFaceLLM
    import runhouse as rh

    gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
    hf = SelfHostedHuggingFaceLLM(
        model_id="google/flan-t5-large", task="text2text-generation", hardware=gpu
    )
    ```

    Example passing a fn that generates the pipeline (because the pipeline is
    not serializable):

    ```python
    from langchain_community.llms import SelfHostedHuggingFaceLLM
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    import runhouse as rh

    def get_pipeline():
        model_id = "gpt2"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        pipe = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
        return pipe

    hf = SelfHostedHuggingFaceLLM(
        model_load_fn=get_pipeline, model_id="gpt2", hardware=gpu)
    ```
    """

    model_id: str = DEFAULT_MODEL_ID
    """Hugging Face model_id to load the model."""
    task: str = DEFAULT_TASK
    """Hugging Face task ("text-generation", "text2text-generation" or
    "summarization")."""
    device: int = 0
    """Device to use for inference. -1 for CPU, 0 for the first GPU, 1 for the
    second GPU, etc."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    hardware: Any
    """Remote hardware to send the inference function to."""
    model_reqs: List[str] = ["./", "transformers", "torch"]
    """Requirements to install on the hardware to run inference with the model."""
    model_load_fn: Callable = _load_transformer
    """Function to load the model remotely on the server."""
    inference_fn: Callable = _generate_text  #: :meta private:
    """Inference function to send to the remote hardware."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def __init__(self, **kwargs: Any):
        """Construct the pipeline remotely using an auxiliary function.

        The load function must be importable so it can be imported
        and run on the server, i.e. it must live in a module rather than a REPL
        or closure. Then, initialize the remote inference function.
        """
        load_fn_kwargs = {
            "model_id": kwargs.get("model_id", DEFAULT_MODEL_ID),
            "task": kwargs.get("task", DEFAULT_TASK),
            "device": kwargs.get("device", 0),
            "model_kwargs": kwargs.get("model_kwargs", None),
        }
        super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_id": self.model_id},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        return "selfhosted_huggingface_pipeline"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return self.client(
            pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
        )
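
# Hedged end-to-end sketch (assumes a `runhouse` cluster named `gpu` already
# exists, as in the docstring examples above). `invoke` ultimately routes through
# `_call`, which forwards the prompt and stop tokens to the remote
# `_generate_text` function:
#
#     hf = SelfHostedHuggingFaceLLM(
#         model_id="gpt2", task="text-generation", hardware=gpu
#     )
#     print(hf.invoke("Tell me a joke", stop=["\n\n"]))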