Source code for langchain_community.llms.huggingface_text_gen_inference

import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_pydantic_field_names

# Module-level logger named after this module; used below to warn when
# unrecognised keyword arguments are moved into `model_kwargs`.
logger = logging.getLogger(__name__)


@deprecated("0.0.21", removal="0.3.0", alternative="HuggingFaceEndpoint")
class HuggingFaceTextGenInference(LLM):
    """HuggingFace text generation API.

    ! This class is deprecated, you should use HuggingFaceEndpoint instead !

    To use, you should have the `text-generation` python package installed
    and a text-generation server running.

    Example:
        .. code-block:: python

            # Basic example (no streaming)
            llm = HuggingFaceTextGenInference(
                inference_server_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
            )
            print(llm.invoke("What is Deep Learning?"))  # noqa: T201

            # Streaming response example
            from langchain_community.callbacks import streaming_stdout

            callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
            llm = HuggingFaceTextGenInference(
                inference_server_url="http://localhost:8010/",
                max_new_tokens=512,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=0.01,
                repetition_penalty=1.03,
                callbacks=callbacks,
                streaming=True
            )
            print(llm.invoke("What is Deep Learning?"))  # noqa: T201
    """

    max_new_tokens: int = 512
    """Maximum number of generated tokens."""
    top_k: Optional[int] = None
    """The number of highest probability vocabulary tokens to keep for
    top-k filtering."""
    top_p: Optional[float] = 0.95
    """If set to < 1, only the smallest set of most probable tokens with
    probabilities that add up to `top_p` or higher are kept."""
    typical_p: Optional[float] = 0.95
    """Typical decoding mass. See [Typical Decoding for Natural Language
    Generation](https://arxiv.org/abs/2202.00666) for more information."""
    temperature: Optional[float] = 0.8
    """The value used to modulate the logits distribution."""
    repetition_penalty: Optional[float] = None
    """The parameter for repetition penalty. 1.0 means no penalty.
    See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details."""
    return_full_text: bool = False
    """Whether to prepend the prompt to the generated text."""
    truncate: Optional[int] = None
    """Truncate input tokens to the given size."""
    stop_sequences: List[str] = Field(default_factory=list)
    """Stop generating tokens if a member of `stop_sequences` is generated."""
    seed: Optional[int] = None
    """Random sampling seed."""
    inference_server_url: str = ""
    """Base URL of the text-generation-inference instance."""
    timeout: int = 120
    """Timeout in seconds."""
    streaming: bool = False
    """Whether to generate a stream of tokens asynchronously."""
    do_sample: bool = False
    """Activate logits sampling."""
    watermark: bool = False
    """Watermarking with [A Watermark for Large Language Models]
    (https://arxiv.org/abs/2301.10226)."""
    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any text-generation-inference server parameters not explicitly
    specified."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `call` not explicitly specified."""
    client: Any
    async_client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key in ``values`` that is not a declared field is moved into
        ``model_kwargs`` (with a warning).

        Raises:
            ValueError: if a key is present both at the top level and inside
                ``model_kwargs``, or if ``model_kwargs`` contains the name of
                a declared field.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        # Work on a copy: writing into the caller's own dict would make a
        # second construction with the same dict fail with "supplied twice".
        extra = dict(values.get("model_kwargs", {}))
        # Iterate over a snapshot of the keys because entries are popped
        # from `values` inside the loop.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"WARNING! {field_name} is not default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists and build both clients.

        Raises:
            ImportError: if the ``text_generation`` package is not installed.
        """
        # Keep the try body limited to the import so that errors raised while
        # constructing the clients are not reported as a missing package.
        try:
            import text_generation
        except ImportError as err:
            raise ImportError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            ) from err

        values["client"] = text_generation.Client(
            values["inference_server_url"],
            timeout=values["timeout"],
            **values["server_kwargs"],
        )
        values["async_client"] = text_generation.AsyncClient(
            values["inference_server_url"],
            timeout=values["timeout"],
            **values["server_kwargs"],
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "huggingface_textgen_inference"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the text generation API.

        Entries in ``model_kwargs`` are merged last.
        """
        return {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "return_full_text": self.return_full_text,
            "truncate": self.truncate,
            "stop_sequences": self.stop_sequences,
            "seed": self.seed,
            "do_sample": self.do_sample,
            "watermark": self.watermark,
            **self.model_kwargs,
        }

    def _invocation_params(
        self, runtime_stop: Optional[List[str]], **kwargs: Any
    ) -> Dict[str, Any]:
        """Merge default params, per-call kwargs and runtime stop sequences.

        Per-call ``kwargs`` override the defaults; ``runtime_stop`` is
        appended to the resulting ``stop_sequences`` in a new list, so the
        configured list on the instance is left untouched.
        """
        params = {**self._default_params, **kwargs}
        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
        return params

    @staticmethod
    def _first_stop_index(text: str, stop_sequences: List[str]) -> Optional[int]:
        """Return the index of the earliest stop sequence found in ``text``.

        Empty stop sequences are ignored, since they would match at index 0
        of every text. Returns ``None`` when no stop sequence occurs.
        """
        hits = [text.find(seq) for seq in stop_sequences if seq]
        hits = [pos for pos in hits if pos != -1]
        return min(hits) if hits else None

    @staticmethod
    def _truncate_at_stop(text: str, stop_sequences: List[str]) -> str:
        """Return ``text`` cut just before its earliest stop sequence, if any."""
        cut = HuggingFaceTextGenInference._first_stop_index(text, stop_sequences)
        return text if cut is None else text[:cut]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` synchronously.

        When ``streaming`` is enabled the text of the chunks yielded by
        ``_stream`` is concatenated; otherwise a single ``generate`` request
        is made and the result is cut before the earliest stop sequence.
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )

        invocation_params = self._invocation_params(stop, **kwargs)
        res = self.client.generate(prompt, **invocation_params)
        # Cut at the earliest stop sequence without modifying the response
        # object returned by the client.
        return self._truncate_at_stop(
            res.generated_text, invocation_params["stop_sequences"]
        )

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` asynchronously.

        Async counterpart of ``_call``: streams through ``_astream`` when
        ``streaming`` is enabled, otherwise awaits a single ``generate``
        request and cuts the result before the earliest stop sequence.
        """
        if self.streaming:
            pieces: List[str] = []
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                pieces.append(chunk.text)
            return "".join(pieces)

        invocation_params = self._invocation_params(stop, **kwargs)
        res = await self.async_client.generate(prompt, **invocation_params)
        # Cut at the earliest stop sequence without modifying the response
        # object returned by the client.
        return self._truncate_at_stop(
            res.generated_text, invocation_params["stop_sequences"]
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield generation chunks token by token.

        Special tokens are never yielded. Stop sequences are looked for
        inside each individual token's text only; when one is found, the
        text before it is yielded (if non-empty) and streaming ends.
        """
        invocation_params = self._invocation_params(stop, **kwargs)
        stop_sequences = invocation_params["stop_sequences"]

        for res in self.client.generate_stream(prompt, **invocation_params):
            token_text = res.token.text
            # Earliest stop sequence inside this token, if any.
            cut = self._first_stop_index(token_text, stop_sequences)

            # Decide what part of the token to emit.
            text: Optional[str] = None
            if res.token.special:
                text = None
            elif cut is not None:
                text = token_text[:cut]
            else:
                text = token_text

            if text:
                chunk = GenerationChunk(text=text)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)
                yield chunk

            # Stop consuming the stream once a stop sequence was seen.
            if cut is not None:
                break

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously yield generation chunks token by token.

        Async counterpart of ``_stream``: special tokens are never yielded,
        stop sequences are looked for inside each individual token's text
        only, and streaming ends at the first token containing one.
        """
        invocation_params = self._invocation_params(stop, **kwargs)
        stop_sequences = invocation_params["stop_sequences"]

        async for res in self.async_client.generate_stream(
            prompt, **invocation_params
        ):
            token_text = res.token.text
            # Earliest stop sequence inside this token, if any.
            cut = self._first_stop_index(token_text, stop_sequences)

            # Decide what part of the token to emit.
            text: Optional[str] = None
            if res.token.special:
                text = None
            elif cut is not None:
                text = token_text[:cut]
            else:
                text = token_text

            if text:
                chunk = GenerationChunk(text=text)
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
                yield chunk

            # Stop consuming the stream once a stop sequence was seen.
            if cut is not None:
                break