"""用于大型语言模型的基本接口。"""
from __future__ import annotations
import asyncio
import functools
import inspect
import json
import logging
import uuid
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
import yaml
from tenacity import (
RetryCallState,
before_sleep_log,
retry,
retry_base,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
BaseCallbackManager,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_messages,
get_buffer_string,
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
logger = logging.getLogger(__name__)
@functools.lru_cache
def _log_error_once(msg: str) -> None:
    """Log ``msg`` at ERROR level.

    The ``lru_cache`` wrapper memoises calls by message text, so an identical
    message is not logged again while it remains in the cache.
    """
    logger.error(msg)
def create_base_retry_decorator(
    error_types: List[Type[BaseException]],
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for an LLM call, given the retryable error types.

    Args:
        error_types: Exception classes that trigger a retry. Must contain at
            least one entry (the first element is indexed directly).
        max_retries: Passed to ``stop_after_attempt``.
        run_manager: Optional callback manager whose ``on_retry`` hook is
            invoked before each sleep.

    Returns:
        A ``tenacity.retry`` decorator that re-raises the last exception.
    """
    log_before_sleep = before_sleep_log(logger, logging.WARNING)

    def _before_sleep(retry_state: RetryCallState) -> None:
        log_before_sleep(retry_state)
        if not run_manager:
            return None
        if not isinstance(run_manager, AsyncCallbackManagerForLLMRun):
            # Synchronous manager: the hook can be called directly.
            run_manager.on_retry(retry_state)
            return None
        # Async manager: on_retry returns a coroutine that must be scheduled
        # on the running loop, or run to completion if no loop is running.
        coro = run_manager.on_retry(retry_state)
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop.create_task(coro)
            else:
                asyncio.run(coro)
        except Exception as e:
            _log_error_once(f"Error in on_retry: {e}")
        return None

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    retry_instance: "retry_base" = functools.reduce(
        lambda combined, error: combined | retry_if_exception_type(error),
        error_types[1:],
        retry_if_exception_type(error_types[0]),
    )
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_instance,
        before_sleep=_before_sleep,
    )
def _resolve_cache(cache: Union[BaseCache, bool, None]) -> Optional[BaseCache]:
    """Resolve a ``cache`` setting to a concrete cache object, or ``None``.

    Args:
        cache: A ``BaseCache`` instance (used as-is), ``None`` (use whatever
            ``get_llm_cache()`` returns, possibly ``None``), ``True`` (require
            ``get_llm_cache()`` to return a cache) or ``False`` (no caching).

    Returns:
        The cache to use, or ``None`` when caching is off.

    Raises:
        ValueError: If ``cache`` is ``True`` but ``get_llm_cache()`` returns
            ``None``, or if ``cache`` is of an unsupported type.
    """
    if isinstance(cache, BaseCache):
        return cache
    if cache is None:
        return get_llm_cache()
    if cache is True:
        llm_cache = get_llm_cache()
        if llm_cache is None:
            # The adjacent literals are concatenated, so each fragment needs
            # its own trailing space.
            raise ValueError(
                "No global cache was configured. Use `set_llm_cache` "
                "to set a global cache if you want to use a global cache. "
                "Otherwise either pass a cache object or set cache to False/None."
            )
        return llm_cache
    if cache is False:
        return None
    raise ValueError(f"Unsupported cache value {cache}")
def get_prompts(
    params: Dict[str, Any],
    prompts: List[str],
    cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
    """Split prompts into those already cached and those still missing.

    Args:
        params: Model parameters; their sorted items form the cache key string.
        prompts: Prompts to look up.
        cache: Cache setting, resolved with ``_resolve_cache``.

    Returns:
        A tuple ``(existing_prompts, llm_string, missing_prompt_idxs,
        missing_prompts)``. ``existing_prompts`` maps prompt index to the
        cached list value. When no cache is resolved, no lookup happens and
        both the existing and missing collections are empty.
    """
    llm_string = str(sorted(params.items()))
    missing_prompts: List[str] = []
    missing_prompt_idxs: List[int] = []
    existing_prompts: Dict[int, List] = {}

    llm_cache = _resolve_cache(cache)
    # Compare against None explicitly: a cache object may define
    # __len__/__bool__ and be falsy while still being a usable cache.
    if llm_cache is not None:
        for i, prompt in enumerate(prompts):
            cache_val = llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
async def aget_prompts(
    params: Dict[str, Any],
    prompts: List[str],
    cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
    """Split prompts into cached and missing ones (async version).

    Args:
        params: Model parameters; their sorted items form the cache key string.
        prompts: Prompts to look up.
        cache: Cache setting, resolved with ``_resolve_cache``.

    Returns:
        A tuple ``(existing_prompts, llm_string, missing_prompt_idxs,
        missing_prompts)``. ``existing_prompts`` maps prompt index to the
        cached list value. When no cache is resolved, no lookup happens and
        both the existing and missing collections are empty.
    """
    llm_string = str(sorted(params.items()))
    missing_prompts: List[str] = []
    missing_prompt_idxs: List[int] = []
    existing_prompts: Dict[int, List] = {}

    llm_cache = _resolve_cache(cache)
    # Compare against None explicitly: a cache object may define
    # __len__/__bool__ and be falsy while still being a usable cache.
    if llm_cache is not None:
        for i, prompt in enumerate(prompts):
            cache_val = await llm_cache.alookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
def update_cache(
    cache: Union[BaseCache, bool, None],
    existing_prompts: Dict[int, List],
    llm_string: str,
    missing_prompt_idxs: List[int],
    new_results: LLMResult,
    prompts: List[str],
) -> Optional[dict]:
    """Store fresh generations in the cache and return the LLM output.

    Each entry of ``new_results.generations`` is written into
    ``existing_prompts`` (mutated in place) at the matching index from
    ``missing_prompt_idxs`` and, if a cache is resolved, pushed to it.

    Returns:
        ``new_results.llm_output``.
    """
    llm_cache = _resolve_cache(cache)
    for position, generation in enumerate(new_results.generations):
        prompt_idx = missing_prompt_idxs[position]
        existing_prompts[prompt_idx] = generation
        prompt = prompts[prompt_idx]
        if llm_cache is not None:
            llm_cache.update(prompt, llm_string, generation)
    return new_results.llm_output
async def aupdate_cache(
    cache: Union[BaseCache, bool, None],
    existing_prompts: Dict[int, List],
    llm_string: str,
    missing_prompt_idxs: List[int],
    new_results: LLMResult,
    prompts: List[str],
) -> Optional[dict]:
    """Store fresh generations in the cache and return the LLM output (async).

    Each entry of ``new_results.generations`` is written into
    ``existing_prompts`` (mutated in place) at the matching index from
    ``missing_prompt_idxs`` and, if a cache is resolved, pushed to it.

    Returns:
        ``new_results.llm_output``.
    """
    llm_cache = _resolve_cache(cache)
    for i, result in enumerate(new_results.generations):
        prompt_idx = missing_prompt_idxs[i]
        existing_prompts[prompt_idx] = result
        prompt = prompts[prompt_idx]
        # Explicit None check (as in the sync update_cache): a cache object
        # that evaluates falsy must still be updated.
        if llm_cache is not None:
            await llm_cache.aupdate(prompt, llm_string, result)
    return new_results.llm_output
class BaseLLM(BaseLanguageModel[str], ABC):
    """Base LLM abstract interface.

    It should take in a prompt and return a string.
    """

    # Deprecated: the root validator of this class moves any value given here
    # to ``callbacks`` and emits a DeprecationWarning.
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """[已弃用]"""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
    """Warn when ``callback_manager`` is used and move its value to ``callbacks``."""
    if values.get("callback_manager") is None:
        return values
    warnings.warn(
        "callback_manager is deprecated. Please use callbacks instead.",
        DeprecationWarning,
    )
    values["callbacks"] = values.pop("callback_manager", None)
    return values
# --- Runnable methods ---
@property
def OutputType(self) -> Type[str]:
    """Get the output type for this runnable, which is always ``str``."""
    return str
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
    """Normalise a model input to a ``PromptValue``.

    A ``PromptValue`` is returned unchanged, a ``str`` becomes a
    ``StringPromptValue`` and any other ``Sequence`` is converted to messages
    and wrapped in a ``ChatPromptValue``.

    Raises:
        ValueError: If ``input`` is none of the supported types.
    """
    if isinstance(input, PromptValue):
        return input
    # str must be tested before Sequence because a str is itself a Sequence.
    if isinstance(input, str):
        return StringPromptValue(text=input)
    if isinstance(input, Sequence):
        return ChatPromptValue(messages=convert_to_messages(input))
    raise ValueError(
        f"Invalid input type {type(input)}. "
        "Must be a PromptValue, str, or list of BaseMessages."
    )
def invoke(
    self,
    input: LanguageModelInput,
    config: Optional[RunnableConfig] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> str:
    """Run the model on one input and return the text of its first generation.

    Callbacks, tags, metadata, run name and run id are taken from ``config``.
    """
    config = ensure_config(config)
    llm_result = self.generate_prompt(
        [self._convert_input(input)],
        stop=stop,
        callbacks=config.get("callbacks"),
        tags=config.get("tags"),
        metadata=config.get("metadata"),
        run_name=config.get("run_name"),
        run_id=config.pop("run_id", None),
        **kwargs,
    )
    first_generation = llm_result.generations[0][0]
    return first_generation.text
async def ainvoke(
    self,
    input: LanguageModelInput,
    config: Optional[RunnableConfig] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> str:
    """Asynchronously run the model on one input and return the first generation's text.

    Callbacks, tags, metadata, run name and run id are taken from ``config``.
    """
    config = ensure_config(config)
    return (
        await self.agenerate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.pop("run_id", None),
            **kwargs,
        )
    ).generations[0][0].text
def batch(
    self,
    inputs: List[LanguageModelInput],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    **kwargs: Any,
) -> List[str]:
    """Run the model on several inputs and return one text per input.

    Without ``max_concurrency`` (read from the first config) all inputs go
    through a single ``generate_prompt`` call. With it, inputs are processed
    in consecutive chunks of that size by recursive calls.

    If ``return_exceptions`` is true, a failure of the single call is reported
    by returning the exception once per input instead of raising.
    """
    if not inputs:
        return []

    config = get_config_list(config, len(inputs))
    max_concurrency = config[0].get("max_concurrency")

    if max_concurrency is not None:
        # Clear max_concurrency so each recursive call takes the
        # single-request path below.
        config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
        outputs: List[str] = []
        for start in range(0, len(inputs), max_concurrency):
            end = start + max_concurrency
            outputs.extend(
                self.batch(
                    inputs[start:end],
                    config=config[start:end],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            )
        return outputs

    try:
        llm_result = self.generate_prompt(
            [self._convert_input(item) for item in inputs],
            callbacks=[c.get("callbacks") for c in config],
            tags=[c.get("tags") for c in config],
            metadata=[c.get("metadata") for c in config],
            run_name=[c.get("run_name") for c in config],
            **kwargs,
        )
        return [candidates[0].text for candidates in llm_result.generations]
    except Exception as e:
        if not return_exceptions:
            raise e
        return cast(List[str], [e for _ in inputs])
async def abatch(
    self,
    inputs: List[LanguageModelInput],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    **kwargs: Any,
) -> List[str]:
    """Asynchronously run the model on several inputs, one text per input.

    Without ``max_concurrency`` (read from the first config) all inputs go
    through a single ``agenerate_prompt`` call. With it, inputs are processed
    in consecutive chunks of that size by recursive calls, awaited one after
    another.

    If ``return_exceptions`` is true, a failure of the single call is reported
    by returning the exception once per input instead of raising.
    """
    if not inputs:
        return []

    config = get_config_list(config, len(inputs))
    max_concurrency = config[0].get("max_concurrency")

    if max_concurrency is not None:
        # Clear max_concurrency so each recursive call takes the
        # single-request path below.
        config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
        outputs: List[str] = []
        for start in range(0, len(inputs), max_concurrency):
            end = start + max_concurrency
            outputs.extend(
                await self.abatch(
                    inputs[start:end],
                    config=config[start:end],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            )
        return outputs

    try:
        llm_result = await self.agenerate_prompt(
            [self._convert_input(item) for item in inputs],
            callbacks=[c.get("callbacks") for c in config],
            tags=[c.get("tags") for c in config],
            metadata=[c.get("metadata") for c in config],
            run_name=[c.get("run_name") for c in config],
            **kwargs,
        )
        return [candidates[0].text for candidates in llm_result.generations]
    except Exception as e:
        if not return_exceptions:
            raise e
        return cast(List[str], [e for _ in inputs])
def stream(
    self,
    input: LanguageModelInput,
    config: Optional[RunnableConfig] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> Iterator[str]:
    """Yield the model output for ``input`` as text chunks.

    If the subclass does not override ``_stream``, the whole ``invoke`` result
    is yielded as a single chunk. Otherwise each chunk's text is yielded as it
    arrives, and the accumulated generation is reported to the run manager on
    completion or on error.
    """
    if type(self)._stream == BaseLLM._stream:
        # model doesn't implement streaming, so use default implementation
        yield self.invoke(input, config=config, stop=stop, **kwargs)
        return

    prompt = self._convert_input(input).to_string()
    config = ensure_config(config)
    params = self.dict()
    params["stop"] = stop
    params = {**params, **kwargs}
    options = {"stop": stop}
    callback_manager = CallbackManager.configure(
        config.get("callbacks"),
        self.callbacks,
        self.verbose,
        config.get("tags"),
        self.tags,
        config.get("metadata"),
        self.metadata,
    )
    (run_manager,) = callback_manager.on_llm_start(
        dumpd(self),
        [prompt],
        invocation_params=params,
        options=options,
        name=config.get("run_name"),
        run_id=config.pop("run_id", None),
        batch_size=1,
    )
    accumulated: Optional[GenerationChunk] = None
    try:
        chunks = self._stream(prompt, stop=stop, run_manager=run_manager, **kwargs)
        for chunk in chunks:
            yield chunk.text
            if accumulated is None:
                accumulated = chunk
            else:
                accumulated += chunk
        assert accumulated is not None
    except BaseException as e:
        # Report whatever was accumulated before the failure.
        partial_result = LLMResult(
            generations=[[accumulated]] if accumulated else []
        )
        run_manager.on_llm_error(e, response=partial_result)
        raise e
    else:
        run_manager.on_llm_end(LLMResult(generations=[[accumulated]]))
async def astream(
    self,
    input: LanguageModelInput,
    config: Optional[RunnableConfig] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> AsyncIterator[str]:
    """Asynchronously yield the model output for ``input`` as text chunks.

    If the subclass overrides neither ``_astream`` nor ``_stream``, the whole
    ``ainvoke`` result is yielded as a single chunk. Otherwise chunks from
    ``_astream`` are yielded as they arrive, and the accumulated generation is
    reported to the run manager on completion or on error.
    """
    no_streaming_support = (
        type(self)._astream is BaseLLM._astream
        and type(self)._stream is BaseLLM._stream
    )
    if no_streaming_support:
        yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
        return

    prompt = self._convert_input(input).to_string()
    config = ensure_config(config)
    params = self.dict()
    params["stop"] = stop
    params = {**params, **kwargs}
    options = {"stop": stop}
    callback_manager = AsyncCallbackManager.configure(
        config.get("callbacks"),
        self.callbacks,
        self.verbose,
        config.get("tags"),
        self.tags,
        config.get("metadata"),
        self.metadata,
    )
    (run_manager,) = await callback_manager.on_llm_start(
        dumpd(self),
        [prompt],
        invocation_params=params,
        options=options,
        name=config.get("run_name"),
        run_id=config.pop("run_id", None),
        batch_size=1,
    )
    accumulated: Optional[GenerationChunk] = None
    try:
        chunks = self._astream(
            prompt,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        async for chunk in chunks:
            yield chunk.text
            if accumulated is None:
                accumulated = chunk
            else:
                accumulated += chunk
        assert accumulated is not None
    except BaseException as e:
        # Report whatever was accumulated before the failure.
        partial_result = LLMResult(
            generations=[[accumulated]] if accumulated else []
        )
        await run_manager.on_llm_error(e, response=partial_result)
        raise e
    else:
        await run_manager.on_llm_end(LLMResult(generations=[[accumulated]]))
# --- Custom methods ---
@abstractmethod
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Run the LLM on the given prompts.

    Abstract: subclasses must provide the implementation.
    """
async def _agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Run the LLM on the given prompts.

    Default implementation: delegates to the synchronous ``_generate`` through
    ``run_in_executor``, passing the sync view of ``run_manager`` if one is
    given.
    """
    sync_manager = run_manager.get_sync() if run_manager else None
    return await run_in_executor(
        None, self._generate, prompts, stop, sync_manager, **kwargs
    )
def _stream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[GenerationChunk]:
    """Stream the LLM on the given prompt.

    This method should be overridden by subclasses that support streaming.
    If it is not implemented, calls to ``stream`` fall back to the
    non-streaming version of the model and return the output as a single
    chunk.

    Args:
        prompt: The prompt to generate from.
        stop: Stop words to use when generating. Model output is cut off at
            the first occurrence of any of these substrings.
        run_manager: Callback manager for the run.
        **kwargs: Arbitrary additional keyword arguments. These are usually
            passed to the model provider API call.

    Returns:
        An iterator of GenerationChunks.
    """
    raise NotImplementedError()
async def _astream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
    """An async version of the ``_stream`` method.

    The default implementation wraps the synchronous ``_stream`` in an async
    iterator, fetching every item through ``run_in_executor``. Subclasses
    that need a truly asynchronous implementation should override this method.

    Args:
        prompt: The prompt to generate from.
        stop: Stop words to use when generating. Model output is cut off at
            the first occurrence of any of these substrings.
        run_manager: Callback manager for the run.
        **kwargs: Arbitrary additional keyword arguments. These are usually
            passed to the model provider API call.

    Returns:
        An async iterator of GenerationChunks.
    """
    sync_manager = run_manager.get_sync() if run_manager else None
    iterator = await run_in_executor(
        None, self._stream, prompt, stop, sync_manager, **kwargs
    )
    # Unique sentinel handed to next() as its default, marking exhaustion
    # without raising StopIteration inside the executor.
    exhausted = object()
    while (
        item := await run_in_executor(
            None,
            next,
            iterator,
            exhausted,  # type: ignore[call-arg, arg-type]
        )
    ) is not exhausted:
        yield item  # type: ignore[misc]
def generate_prompt(
    self,
    prompts: List[PromptValue],
    stop: Optional[List[str]] = None,
    callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
    **kwargs: Any,
) -> LLMResult:
    """Convert each prompt value to a string and delegate to ``generate``."""
    return self.generate(
        [prompt.to_string() for prompt in prompts],
        stop=stop,
        callbacks=callbacks,
        **kwargs,
    )
async def agenerate_prompt(
    self,
    prompts: List[PromptValue],
    stop: Optional[List[str]] = None,
    callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
    **kwargs: Any,
) -> LLMResult:
    """Convert each prompt value to a string and delegate to ``agenerate``."""
    as_strings = [prompt.to_string() for prompt in prompts]
    llm_result = await self.agenerate(
        as_strings, stop=stop, callbacks=callbacks, **kwargs
    )
    return llm_result
def _generate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
self._generate(
prompts,
stop=stop,
# TODO: support multiple run managers
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except BaseException as e:
for run_manager in run_managers:
run_manager.on_llm_error(e, response=LLMResult(generations=[]))
raise e
flattened_outputs = output.flatten()
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
    *,
    tags: Optional[Union[List[str], List[List[str]]]] = None,
    metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
    run_name: Optional[Union[str, List[str]]] = None,
    run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
    **kwargs: Any,
) -> LLMResult:
    """Pass a sequence of prompts to the model and return its generations.

    This method should make use of batched calls for models that expose a
    batched API.

    Use this method when you want to:
        1. take advantage of batched calls,
        2. get more output from the model than just the top generated value,
        3. build chains that are agnostic to the underlying language model
           type (e.g., pure text completion models vs chat models).

    Args:
        prompts: List of string prompts.
        stop: Stop words to use when generating. Model output is cut off at
            the first occurrence of any of these substrings.
        callbacks: Callbacks to pass through. Either one value applied to
            every prompt, or a list with one entry per prompt.
        tags: Tags for the callback managers (per prompt when ``callbacks``
            is per prompt).
        metadata: Metadata for the callback managers (same convention).
        run_name: Run name(s) passed to ``on_llm_start``.
        run_id: Run id(s); expanded by ``_get_run_ids_list``.
        **kwargs: Arbitrary additional keyword arguments. These are usually
            passed to the model provider API call.

    Returns:
        An LLMResult, which contains a list of candidate generations for each
        input prompt and additional model provider-specific output.

    Raises:
        ValueError: If ``prompts`` is not a list.
    """
    if not isinstance(prompts, list):
        raise ValueError(
            "Argument 'prompts' is expected to be of type List[str], received"
            f" argument of type {type(prompts)}."
        )
    num_prompts = len(prompts)

    # Create callback managers
    per_prompt_callbacks = (
        isinstance(callbacks, list)
        and callbacks
        and (
            isinstance(callbacks[0], (list, BaseCallbackManager))
            or callbacks[0] is None
        )
    )
    if per_prompt_callbacks:
        # We've received a list of callbacks args to apply to each input
        assert len(callbacks) == len(prompts)
        assert tags is None or (
            isinstance(tags, list) and len(tags) == len(prompts)
        )
        assert metadata is None or (
            isinstance(metadata, list) and len(metadata) == len(prompts)
        )
        assert run_name is None or (
            isinstance(run_name, list) and len(run_name) == len(prompts)
        )
        callbacks = cast(List[Callbacks], callbacks)
        tags_list = cast(List[Optional[List[str]]], tags or ([None] * num_prompts))
        metadata_list = cast(
            List[Optional[Dict[str, Any]]], metadata or ([{}] * num_prompts)
        )
        run_name_list = run_name or cast(
            List[Optional[str]], ([None] * num_prompts)
        )
        callback_managers = [
            CallbackManager.configure(
                callback,
                self.callbacks,
                self.verbose,
                tag,
                self.tags,
                meta,
                self.metadata,
            )
            for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
        ]
    else:
        # We've received a single callbacks arg to apply to all inputs
        shared_manager = CallbackManager.configure(
            cast(Callbacks, callbacks),
            self.callbacks,
            self.verbose,
            cast(List[str], tags),
            self.tags,
            cast(Dict[str, Any], metadata),
            self.metadata,
        )
        callback_managers = [shared_manager] * num_prompts
        run_name_list = [cast(Optional[str], run_name)] * num_prompts

    run_ids_list = self._get_run_ids_list(run_id, prompts)
    params = self.dict()
    params["stop"] = stop
    options = {"stop": stop}
    (
        existing_prompts,
        llm_string,
        missing_prompt_idxs,
        missing_prompts,
    ) = get_prompts(params, prompts, self.cache)
    new_arg_supported = inspect.signature(self._generate).parameters.get(
        "run_manager"
    )

    caching_disabled = (
        self.cache is None and get_llm_cache() is None
    ) or self.cache is False
    if caching_disabled:
        # No cache in play: run every prompt through the model.
        run_managers = [
            manager.on_llm_start(
                dumpd(self),
                [prompt],
                invocation_params=params,
                options=options,
                name=name,
                batch_size=num_prompts,
                run_id=run_id_,
            )[0]
            for manager, prompt, name, run_id_ in zip(
                callback_managers, prompts, run_name_list, run_ids_list
            )
        ]
        return self._generate_helper(
            prompts, stop, run_managers, bool(new_arg_supported), **kwargs
        )

    if not missing_prompts:
        # Every prompt was served from the cache.
        llm_output = {}
        run_info = None
    else:
        # Only the cache misses are sent to the model.
        run_managers = [
            callback_managers[idx].on_llm_start(
                dumpd(self),
                [prompts[idx]],
                invocation_params=params,
                options=options,
                name=run_name_list[idx],
                batch_size=len(missing_prompts),
            )[0]
            for idx in missing_prompt_idxs
        ]
        new_results = self._generate_helper(
            missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
        )
        llm_output = update_cache(
            self.cache,
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            new_results,
            prompts,
        )
        run_info = (
            [RunInfo(run_id=manager.run_id) for manager in run_managers]
            if run_managers
            else None
        )
    generations = [existing_prompts[i] for i in range(len(prompts))]
    return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
@staticmethod
def _get_run_ids_list(
run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]], prompts: list
) -> list:
if run_id is None:
return [None] * len(prompts)
if isinstance(run_id, list):
if len(run_id) != len(prompts):
raise ValueError(
"Number of manually provided run_id's does not match batch length."
f" {len(run_id)} != {len(prompts)}"
)
return run_id
return [run_id] + [None] * (len(prompts) - 1)
async def _agenerate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
await self._agenerate(
prompts,
stop=stop,
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except BaseException as e:
await asyncio.gather(
*[
run_manager.on_llm_error(e, response=LLMResult(generations=[]))
for run_manager in run_managers
]
)
raise e
flattened_outputs = output.flatten()
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
async def agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
    *,
    tags: Optional[Union[List[str], List[List[str]]]] = None,
    metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
    run_name: Optional[Union[str, List[str]]] = None,
    run_id: Optional[Union[uuid.UUID, List[Optional[uuid.UUID]]]] = None,
    **kwargs: Any,
) -> LLMResult:
    """Asynchronously pass a sequence of prompts to the model and return its generations.

    This method should make use of batched calls for models that expose a
    batched API.

    Use this method when you want to:
        1. take advantage of batched calls,
        2. get more output from the model than just the top generated value,
        3. build chains that are agnostic to the underlying language model
           type (e.g., pure text completion models vs chat models).

    Args:
        prompts: List of string prompts.
        stop: Stop words to use when generating. Model output is cut off at
            the first occurrence of any of these substrings.
        callbacks: Callbacks to pass through. Either one value applied to
            every prompt, or a list with one entry per prompt. An empty list
            is treated as a single value applied to every prompt.
        tags: Tags for the callback managers (per prompt when ``callbacks``
            is per prompt).
        metadata: Metadata for the callback managers (same convention).
        run_name: Run name(s) passed to ``on_llm_start``.
        run_id: Run id(s); expanded by ``_get_run_ids_list``.
        **kwargs: Arbitrary additional keyword arguments. These are usually
            passed to the model provider API call.

    Returns:
        An LLMResult, which contains a list of candidate generations for each
        input prompt and additional model provider-specific output.
    """
    # Create callback managers
    if (
        isinstance(callbacks, list)
        # Check for emptiness before indexing: callbacks == [] would
        # otherwise raise IndexError on callbacks[0].
        and callbacks
        and (
            isinstance(callbacks[0], (list, BaseCallbackManager))
            or callbacks[0] is None
        )
    ):
        # We've received a list of callbacks args to apply to each input
        assert len(callbacks) == len(prompts)
        assert tags is None or (
            isinstance(tags, list) and len(tags) == len(prompts)
        )
        assert metadata is None or (
            isinstance(metadata, list) and len(metadata) == len(prompts)
        )
        assert run_name is None or (
            isinstance(run_name, list) and len(run_name) == len(prompts)
        )
        callbacks = cast(List[Callbacks], callbacks)
        tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
        metadata_list = cast(
            List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
        )
        run_name_list = run_name or cast(
            List[Optional[str]], ([None] * len(prompts))
        )
        callback_managers = [
            AsyncCallbackManager.configure(
                callback,
                self.callbacks,
                self.verbose,
                tag,
                self.tags,
                meta,
                self.metadata,
            )
            for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
        ]
    else:
        # We've received a single callbacks arg to apply to all inputs
        callback_managers = [
            AsyncCallbackManager.configure(
                cast(Callbacks, callbacks),
                self.callbacks,
                self.verbose,
                cast(List[str], tags),
                self.tags,
                cast(Dict[str, Any], metadata),
                self.metadata,
            )
        ] * len(prompts)
        run_name_list = [cast(Optional[str], run_name)] * len(prompts)
    run_ids_list = self._get_run_ids_list(run_id, prompts)
    params = self.dict()
    params["stop"] = stop
    options = {"stop": stop}
    (
        existing_prompts,
        llm_string,
        missing_prompt_idxs,
        missing_prompts,
    ) = await aget_prompts(params, prompts, self.cache)

    # Verify whether the cache is set, and if the cache is set,
    # verify whether the cache is available.
    new_arg_supported = inspect.signature(self._agenerate).parameters.get(
        "run_manager"
    )
    if (self.cache is None and get_llm_cache() is None) or self.cache is False:
        # No cache in play: run every prompt through the model.
        run_managers = await asyncio.gather(
            *[
                callback_manager.on_llm_start(
                    dumpd(self),
                    [prompt],
                    invocation_params=params,
                    options=options,
                    name=run_name,
                    batch_size=len(prompts),
                    run_id=run_id_,
                )
                for callback_manager, prompt, run_name, run_id_ in zip(
                    callback_managers, prompts, run_name_list, run_ids_list
                )
            ]
        )
        run_managers = [r[0] for r in run_managers]  # type: ignore[misc]
        output = await self._agenerate_helper(
            prompts,
            stop,
            run_managers,  # type: ignore[arg-type]
            bool(new_arg_supported),
            **kwargs,  # type: ignore[arg-type]
        )
        return output
    if len(missing_prompts) > 0:
        # Only the cache misses are sent to the model.
        run_managers = await asyncio.gather(
            *[
                callback_managers[idx].on_llm_start(
                    dumpd(self),
                    [prompts[idx]],
                    invocation_params=params,
                    options=options,
                    name=run_name_list[idx],
                    batch_size=len(missing_prompts),
                )
                for idx in missing_prompt_idxs
            ]
        )
        run_managers = [r[0] for r in run_managers]  # type: ignore[misc]
        new_results = await self._agenerate_helper(
            missing_prompts,
            stop,
            run_managers,  # type: ignore[arg-type]
            bool(new_arg_supported),
            **kwargs,  # type: ignore[arg-type]
        )
        llm_output = await aupdate_cache(
            self.cache,
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            new_results,
            prompts,
        )
        run_info = (
            [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]  # type: ignore[attr-defined]
            if run_managers
            else None
        )
    else:
        # Every prompt was served from the cache.
        llm_output = {}
        run_info = None
    generations = [existing_prompts[i] for i in range(len(prompts))]
    return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
def __call__(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
    *,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> str:
    """Check the cache and run the LLM on the given prompt and input.

    Returns:
        The text of the first generation for ``prompt``.

    Raises:
        ValueError: If ``prompt`` is not a string.
    """
    if not isinstance(prompt, str):
        raise ValueError(
            "Argument `prompt` is expected to be a string. Instead found "
            f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
            "`generate` instead."
        )
    llm_result = self.generate(
        [prompt],
        stop=stop,
        callbacks=callbacks,
        tags=tags,
        metadata=metadata,
        **kwargs,
    )
    return llm_result.generations[0][0].text
async def _call_async(
self,
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""检查缓存并在给定的提示和输入上运行LLM。"""
result = await self.agenerate(
[prompt],
stop=stop,
callbacks=callbacks,
tags=tags,
metadata=metadata,
**kwargs,
)
return result.generations[0][0].text
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
def predict(
    self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
    """Run the LLM on ``text`` and return the generated string.

    ``stop`` is copied into a list (when given) before being forwarded.
    """
    stop_words = None if stop is None else list(stop)
    return self(text, stop=stop_words, **kwargs)
@deprecated("0.1.7", alternative="invoke", removal="0.3.0")
def predict_messages(
    self,
    messages: List[BaseMessage],
    *,
    stop: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> BaseMessage:
    """Run the LLM on a message list and wrap the output in an ``AIMessage``.

    The messages are flattened to one prompt with ``get_buffer_string``;
    ``stop`` is copied into a list (when given) before being forwarded.
    """
    prompt_text = get_buffer_string(messages)
    stop_words = None if stop is None else list(stop)
    return AIMessage(content=self(prompt_text, stop=stop_words, **kwargs))
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
async def apredict(
    self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
    """Asynchronously run the LLM on ``text`` and return the generated string.

    ``stop`` is copied into a list (when given) before being forwarded.
    """
    stop_words = None if stop is None else list(stop)
    return await self._call_async(text, stop=stop_words, **kwargs)
@deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
async def apredict_messages(
    self,
    messages: List[BaseMessage],
    *,
    stop: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> BaseMessage:
    """Asynchronously run the LLM on a message list and wrap the output in an ``AIMessage``.

    The messages are flattened to one prompt with ``get_buffer_string``;
    ``stop`` is copied into a list (when given) before being forwarded.
    """
    prompt_text = get_buffer_string(messages)
    stop_words = None if stop is None else list(stop)
    generated = await self._call_async(prompt_text, stop=stop_words, **kwargs)
    return AIMessage(content=generated)
def __str__(self) -> str:
"""获取对象的字符串表示以便打印。"""
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
return f"{cls_name}\nParams: {self._identifying_params}"
@property
@abstractmethod
def _llm_type(self) -> str:
    """Return the type of this LLM as a string.

    Abstract: subclasses must provide the value.
    """
def dict(self, **kwargs: Any) -> Dict:
    """Return a dictionary of the LLM.

    The result is a copy of ``_identifying_params`` with ``_llm_type`` stored
    under the ``"_type"`` key. ``kwargs`` are accepted but not used.
    """
    return {**self._identifying_params, "_type": self._llm_type}
def save(self, file_path: Union[Path, str]) -> None:
    """Save the LLM.

    The parent directory is created if needed, then ``self.dict()`` is written
    as JSON (``.json``) or YAML (``.yaml`` / ``.yml``).

    Args:
        file_path: Path of the file to save the LLM to.

    Raises:
        ValueError: If the file suffix is neither JSON nor YAML.

    Example:
    .. code-block:: python

        llm.save(file_path="path/llm.yaml")
    """
    # Convert file to Path object.
    save_path = Path(file_path) if isinstance(file_path, str) else file_path
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch dictionary to save
    prompt_dict = self.dict()

    suffix = save_path.suffix
    if suffix == ".json":
        with open(file_path, "w") as f:
            json.dump(prompt_dict, f, indent=4)
        return
    if suffix.endswith((".yaml", ".yml")):
        with open(file_path, "w") as f:
            yaml.dump(prompt_dict, f, default_flow_style=False)
        return
    raise ValueError(f"{save_path} must be json or yaml")
class LLM(BaseLLM):
    """Simple interface for implementing a custom LLM.

    You should subclass this class and implement the following:

    - `_call` method: run the LLM on the given prompt and input (used by
      `invoke`).
    - `_identifying_params` property: return a dictionary of the identifying
      parameters. This is critical for caching and tracing purposes. The
      identifying parameters are a dict that identifies the LLM; it should
      mostly include a `model_name`.

    Optional: override the following methods to provide more optimizations:

    - `_acall`: provide a native async version of the `_call` method. If not
      provided, it delegates to the synchronous version using
      `run_in_executor` (used by `ainvoke`).
    - `_stream`: stream the LLM on the given prompt and input. `stream` will
      use `_stream` if provided, otherwise it uses `_call` and the output
      arrives all at once.
    - `_astream`: override to provide a native async version of the `_stream`
      method. `astream` will use `_astream` if provided, otherwise it falls
      back to `_stream` if `_stream` is implemented, and to `_acall` if it is
      not.
    """

    @abstractmethod
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Override this method to implement the LLM logic.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings. If stop
                tokens are not supported, consider raising
                NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are
                usually passed to the model provider API call.

        Returns:
            The model output as a string. It should NOT include the prompt.
        """

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async version of the `_call` method.

        The default implementation delegates to the synchronous `_call`
        method using `run_in_executor`. Subclasses that need a truly
        asynchronous implementation should override this method to avoid the
        overhead of `run_in_executor`.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings. If stop
                tokens are not supported, consider raising
                NotImplementedError.
            run_manager: Callback manager for the run.
            **kwargs: Arbitrary additional keyword arguments. These are
                usually passed to the model provider API call.

        Returns:
            The model output as a string. It should NOT include the prompt.
        """
        sync_manager = run_manager.get_sync() if run_manager else None
        return await run_in_executor(
            None, self._call, prompt, stop, sync_manager, **kwargs
        )

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts, one `_call` per prompt.

        `run_manager` is forwarded only if `_call` declares a `run_manager`
        parameter.
        """
        # TODO: add caching here.
        accepts_run_manager = inspect.signature(self._call).parameters.get(
            "run_manager"
        )
        generations = []
        for prompt in prompts:
            if accepts_run_manager:
                text = self._call(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                text = self._call(prompt, stop=stop, **kwargs)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts, awaiting one `_acall` per prompt.

        `run_manager` is forwarded only if `_acall` declares a `run_manager`
        parameter.
        """
        accepts_run_manager = inspect.signature(self._acall).parameters.get(
            "run_manager"
        )
        generations = []
        for prompt in prompts:
            if accepts_run_manager:
                text = await self._acall(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                text = await self._acall(prompt, stop=stop, **kwargs)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)