# flake8: noqa
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from langchain_core.pydantic_v1 import root_validator
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_core.outputs import GenerationChunk
class DeepSparse(LLM):
    """Neural Magic DeepSparse LLM interface.

    To use, you should have the ``deepsparse`` or ``deepsparse-nightly``
    python package installed. See https://github.com/neuralmagic/deepsparse

    This interface lets you deploy optimized LLMs straight from the
    [SparseZoo](https://sparsezoo.neuralmagic.com/?useCase=text_generation)

    Example:
        .. code-block:: python

            from langchain_community.llms import DeepSparse
            llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
    """  # noqa: E501

    pipeline: Any  #: :meta private:

    model: str
    """The path to a model file or directory or the name of a SparseZoo model stub."""

    model_config: Optional[Dict[str, Any]] = None
    """Keyword arguments passed to the pipeline construction.
    Common parameters are sequence_length, prompt_sequence_length."""

    generation_config: Union[None, str, Dict] = None
    """GenerationConfig dictionary consisting of parameters used to control
    sequences generated for each prompt. Common parameters are:
    max_length, max_new_tokens, num_return_sequences, output_scores,
    top_p, top_k, repetition_penalty."""

    streaming: bool = False
    """Whether to stream the results, token by token."""

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "model_config": self.model_config,
            "generation_config": self.generation_config,
            "streaming": self.streaming,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "deepsparse"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the ``deepsparse`` package is installed and build the pipeline."""
        try:
            from deepsparse import Pipeline
        except ImportError:
            raise ImportError(
                "Could not import `deepsparse` package. "
                "Please install it with `pip install deepsparse[llm]`"
            )
        model_config = values["model_config"] or {}
        values["pipeline"] = Pipeline.create(
            task="text_generation",
            model_path=values["model"],
            **model_config,
        )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate text from a prompt.

        Args:
            prompt: The prompt to generate text from.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import DeepSparse
                llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
                llm.invoke("Tell me a joke.")
        """
        if self.streaming:
            combined_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_output += chunk.text
            text = combined_output
        else:
            # ``generation_config`` defaults to None; guard before ** unpacking,
            # otherwise the call raises TypeError for the default configuration.
            text = (
                self.pipeline(sequences=prompt, **(self.generation_config or {}))
                .generations[0]
                .text
            )
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate text from a prompt (async entry point).

        Args:
            prompt: The prompt to generate text from.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain_community.llms import DeepSparse
                llm = DeepSparse(model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none")
                llm.invoke("Tell me a joke.")
        """
        if self.streaming:
            combined_output = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_output += chunk.text
            text = combined_output
        else:
            # NOTE(review): this pipeline call is synchronous and will block
            # the event loop while generating — consider run_in_executor.
            text = (
                self.pipeline(sequences=prompt, **(self.generation_config or {}))
                .generations[0]
                .text
            )
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            A dictionary-like object containing a string token.

        Example:
            .. code-block:: python

                from langchain_community.llms import DeepSparse
                llm = DeepSparse(
                    model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none",
                    streaming=True
                )
                for chunk in llm.stream("Tell me a joke",
                        stop=["'", "\\n"]):
                    print(chunk, end='', flush=True)  # noqa: T201
        """
        # NOTE(review): ``stop`` is accepted but not enforced here; stop-token
        # trimming happens only in _call/_acall after the stream is consumed.
        inference = self.pipeline(
            sequences=prompt, streaming=True, **(self.generation_config or {})
        )
        for token in inference:
            chunk = GenerationChunk(text=token.generations[0].text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Yield result objects as they are generated in real time (async).

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            A dictionary-like object containing a string token.

        Example:
            .. code-block:: python

                from langchain_community.llms import DeepSparse
                llm = DeepSparse(
                    model="zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base_quant-none",
                    streaming=True
                )
                for chunk in llm.stream("Tell me a joke",
                        stop=["'", "\\n"]):
                    print(chunk, end='', flush=True)  # noqa: T201
        """
        # NOTE(review): the pipeline is iterated synchronously inside this
        # coroutine, so token production blocks the event loop between yields.
        inference = self.pipeline(
            sequences=prompt, streaming=True, **(self.generation_config or {})
        )
        for token in inference:
            chunk = GenerationChunk(text=token.generations[0].text)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(token=chunk.text)