from __future__ import annotations
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
)
from typing_extensions import TypeAlias
from langchain_core._api import deprecated
from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
if TYPE_CHECKING:
from langchain_core.caches import BaseCache
from langchain_core.callbacks import Callbacks
from langchain_core.outputs import LLMResult
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
try:
from transformers import GPT2TokenizerFast # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import transformers python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install transformers`."
)
# create a GPT-2 tokenizer instance
return GPT2TokenizerFast.from_pretrained("gpt2")
def _get_token_ids_default_method(text: str) -> List[int]:
"""将文本编码为标记ID。"""
# get the cached tokenizer
tokenizer = get_tokenizer()
# tokenize the text using the GPT-2 tokenizer
return tokenizer.encode(text)
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
[docs]class BaseLanguageModel(
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC
):
"""用于与语言模型进行交互的抽象基类。
所有语言模型包装器都继承自BaseLanguageModel。
"""
cache: Union[BaseCache, bool, None] = None
"""是否缓存响应。
* 如果为True,则将使用全局缓存。
* 如果为False,则不使用缓存。
* 如果为None,则如果设置了全局缓存,则使用全局缓存,否则不使用缓存。
* 如果为BaseCache的实例,则将使用提供的缓存。
目前不支持对模型的流式方法进行缓存。"""
verbose: bool = Field(default_factory=_get_verbosity)
"""是否打印响应文本。"""
callbacks: Callbacks = Field(default=None, exclude=True)
"""运行跟踪中需要添加的回调函数。"""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""要添加到运行跟踪中的标签。"""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""要添加到运行跟踪中的元数据。"""
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
default=None, exclude=True
)
"""用于计算标记的可选编码器。"""
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""如果verbose为None,则设置它。
这允许用户传入None作为verbose来访问全局设置。
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
def InputType(self) -> TypeAlias:
"""获取此可运行对象的输入类型。"""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
[docs] @abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""将一系列提示传递给模型,并返回模型生成的结果。
对于暴露批处理API的模型,此方法应该使用批处理调用。
在以下情况下使用此方法:
1. 利用批处理调用,
2. 需要从模型获得更多输出而不仅仅是顶部生成的值,
3. 正在构建对基础语言模型类型不可知的链(例如,纯文本完成模型与聊天模型)。
参数:
prompts: PromptValues列表。PromptValue是一个对象,可以转换为与任何语言模型匹配的格式(纯文本生成模型的字符串和聊天模型的BaseMessages)。
stop: 生成时要使用的停止词。模型输出在出现这些子字符串的第一次截断。
callbacks: 要传递的回调。用于在生成过程中执行额外功能,如日志记录或流式处理。
**kwargs: 任意额外的关键字参数。这些通常传递给模型提供者API调用。
返回:
一个LLMResult,其中包含每个输入提示的候选生成列表和额外的模型提供者特定输出。
"""
[docs] @abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""异步传递一系列提示并返回模型生成。
此方法应该利用暴露批处理API的模型进行批处理调用。
在以下情况下使用此方法:
1. 利用批处理调用,
2. 需要模型生成的更多输出而不仅仅是顶部生成的值,
3. 构建对基础语言模型类型不可知的链式结构(例如,纯文本完成模型与聊天模型)。
参数:
prompts: PromptValues的列表。PromptValue是一个对象,可以转换为与任何语言模型的格式匹配(纯文本生成模型的字符串和聊天模型的BaseMessages)。
stop: 生成时要使用的停止词。模型输出在出现这些子字符串的第一次发生时被截断。
callbacks: 要传递的回调。用于执行生成过程中的额外功能,例如日志记录或流式传输。
**kwargs: 任意的额外关键字参数。这些通常传递给模型提供者的API调用。
返回:
一个LLMResult,其中包含每个输入提示的候选生成列表和额外的模型提供者特定输出。
"""
[docs] def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""在这个类中未实现。"""
# Implement this on child class if there is a way of steering the model to
# generate responses that match a given schema.
raise NotImplementedError()
[docs] @deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@abstractmethod
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""将单个字符串输入传递给模型并返回一个字符串。
在传递原始文本时使用此方法。如果要传递特定类型的聊天消息,请使用predict_messages。
参数:
text: 要传递给模型的字符串输入。
stop: 生成时要使用的停用词。模型输出在出现这些子字符串的第一次出现时被截断。
**kwargs: 任意额外的关键字参数。这些通常被传递给模型提供者的API调用。
返回:
作为字符串的顶级模型预测。
"""
[docs] @deprecated("0.1.7", alternative="invoke", removal="0.3.0")
@abstractmethod
def predict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""将消息序列传递给模型并返回一条消息。
在传递聊天消息时使用此方法。如果要传递原始文本,
使用predict。
参数:
messages:与单个模型输入对应的聊天消息序列。
stop:在生成时要使用的停用词。模型输出会在出现这些子字符串中的任何一个时被截断。
**kwargs:任意额外的关键字参数。这些通常会传递给模型提供者的API调用。
返回:
作为消息的顶级模型预测。
"""
[docs] @deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@abstractmethod
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""将一个字符串异步传递给模型,并返回一个字符串。
在调用纯文本生成模型并且只需要最高候选生成时,请使用此方法。
参数:
text: 要传递给模型的字符串输入。
stop: 生成时要使用的停止词。模型输出在出现这些子字符串的第一次出现时被截断。
**kwargs: 任意额外的关键字参数。这些通常被传递给模型提供者的API调用。
返回:
作为字符串的顶级模型预测。
"""
[docs] @deprecated("0.1.7", alternative="ainvoke", removal="0.3.0")
@abstractmethod
async def apredict_messages(
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""异步地向模型传递消息并返回一条消息。
在调用聊天模型并且只需要生成顶部候选项时使用此方法。
参数:
messages:与单个模型输入对应的一系列聊天消息。
stop:在生成时使用的停用词。模型输出在出现这些子字符串的第一次发生时被截断。
**kwargs:任意额外的关键字参数。这些通常被传递给模型提供者的API调用。
返回:
作为消息的顶部模型预测。
"""
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""获取识别参数。"""
return self.lc_attributes
[docs] def get_token_ids(self, text: str) -> List[int]:
"""返回文本中标记的令牌的有序id。
参数:
text:要进行标记化的字符串输入。
返回:
返回一个id列表,对应于文本中的标记,按照它们在文本中出现的顺序。
"""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
else:
return _get_token_ids_default_method(text)
[docs] def get_num_tokens(self, text: str) -> int:
"""获取文本中存在的标记数量。
用于检查输入是否适合模型的上下文窗口。
参数:
text:要标记化的字符串输入。
返回:
文本中标记的整数数量。
"""
return len(self.get_token_ids(text))
[docs] def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""获取消息中的令牌数量。
用于检查输入是否适合模型的上下文窗口。
参数:
messages:要进行标记化的消息输入。
返回:
跨消息中令牌数量的总和。
"""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def _all_required_field_names(cls) -> Set:
"""已弃用:保留以确保向后兼容性。
请使用get_pydantic_field_names。
"""
return get_pydantic_field_names(cls)