"""ChatYuan2包装器。"""
from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
# Module-level logger; used in this file for the extra-parameter warning,
# the debug traces of raw responses, and tenacity's before-sleep retry messages.
logger = logging.getLogger(__name__)
class ChatYuan2(BaseChatModel):
    """`Yuan2.0` chat models API.

    To use, you should have the ``openai-python`` package installed; if it is
    not installed, install it with ```pip install openai```.
    The environment variable ``YUAN2_API_KEY`` should be set to your API key;
    if it is not set, anyone can access the API.

    Any parameters that are valid to be passed to the openai.create call can
    be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatYuan2

            chat = ChatYuan2()
    """

    client: Any  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(default="yuan2", alias="model")
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call that are not
    explicitly specified."""
    yuan2_api_key: Optional[str] = Field(default="EMPTY", alias="api_key")
    """Automatically inferred from the env var `YUAN2_API_KEY` if not provided."""
    yuan2_api_base: Optional[str] = Field(
        default="http://127.0.0.1:8000/v1", alias="base_url"
    )
    """Base URL path for API requests; an OpenAI-compatible API server."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to the yuan2 completion API. When left as None the
    underlying client's own default applies (noted upstream as 600 seconds)."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""
    temperature: float = 1.0
    """What sampling temperature to use."""
    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""
    stop: Optional[List[str]] = ["<eod>"]
    """A list of strings to stop generation when encountered."""
    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize."""
    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret field names to the environment variables holding them."""
        return {"yuan2_api_key": "YUAN2_API_KEY"}

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return the API base and key as extra attributes when they are set."""
        attributes: Dict[str, Any] = {}

        if self.yuan2_api_base:
            attributes["yuan2_api_base"] = self.yuan2_api_base

        if self.yuan2_api_key:
            attributes["yuan2_api_key"] = self.yuan2_api_key

        return attributes

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key that is not a declared field is moved into ``model_kwargs``
        (with a warning).

        Raises:
            ValueError: If a key is supplied both directly and inside
                ``model_kwargs``, or if ``model_kwargs`` contains a declared
                field name.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because entries are popped below.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"WARNING! {field_name} is not default parameter.\n"
                    f"{field_name} was transferred to model_kwargs.\n"
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the python package exist in the environment.

        Also builds the sync and async chat-completions clients unless the
        caller already supplied them.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        values["yuan2_api_key"] = get_from_dict_or_env(
            values, "yuan2_api_key", "YUAN2_API_KEY"
        )

        try:
            import openai
        except ImportError as exc:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc

        client_params = {
            "api_key": values["yuan2_api_key"],
            "base_url": values["yuan2_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
        }

        # generate client and async_client
        if not values.get("client"):
            values["client"] = openai.OpenAI(**client_params).chat.completions
        if not values.get("async_client"):
            values["async_client"] = openai.AsyncOpenAI(
                **client_params
            ).chat.completions

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the yuan2 API.

        Note that ``stop``, ``repeat_last_n`` and ``repeat_penalty`` are not
        included here; only the keys built below plus ``model_kwargs`` are sent.
        """
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "temperature": self.temperature,
            "top_p": self.top_p,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.request_timeout is not None:
            # The openai v1 `chat.completions.create` call names its per-request
            # override `timeout`; `request_timeout` is rejected as an unexpected
            # keyword argument.
            params["timeout"] = self.request_timeout
        return params

    def completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call.

        Args:
            **kwargs: Forwarded unchanged to ``self.client.create``.

        Returns:
            Whatever ``self.client.create`` returns.
        """
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Sum the token-usage counters of several LLM outputs.

        Args:
            llm_outputs: Per-call ``llm_output`` dicts; entries may be None.

        Returns:
            A dict with the summed ``token_usage`` and this model's name.
        """
        overall_token_usage: dict = {}
        logger.debug(
            f"type(llm_outputs): {type(llm_outputs)}; llm_outputs: {llm_outputs}"
        )
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            # Tolerate outputs whose usage block is missing or null.
            token_usage = output.get("token_usage") or {}
            for k, v in token_usage.items():
                if v is None:
                    # A null counter cannot be summed; skip it.
                    continue
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model_name}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the completion, yielding one generation chunk per delta.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences for this call.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra parameters merged over the invocation params.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for raw_chunk in self.completion_with_retry(messages=message_dicts, **params):
            if not isinstance(raw_chunk, dict):
                raw_chunk = raw_chunk.model_dump()
            # Skip chunks that carry no choices (empty or null list).
            if not raw_chunk["choices"]:
                continue
            choice = raw_chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            # Later deltas reuse the class chosen for the previous chunk.
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk,
                generation_info=generation_info,
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result, via streaming when ``self.streaming`` is set."""
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert messages to request dicts and assemble the call parameters.

        Raises:
            ValueError: If ``stop`` is given and also present in the
                invocation params.
        """
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        """Turn a completion response into a ``ChatResult``.

        Args:
            response: The response as a dict, or an object exposing ``.dict()``.
        """
        generations = []
        logger.debug(f"type(response): {type(response)}; response: {response}")
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res["finish_reason"])
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        llm_output = {
            # `usage` may be present but null; normalise that to an empty dict.
            "token_usage": response.get("usage") or {},
            "model_name": self.model_name,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``; yields one chunk per delta."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for raw_chunk in await acompletion_with_retry(
            self, messages=message_dicts, **params
        ):
            if not isinstance(raw_chunk, dict):
                raw_chunk = raw_chunk.model_dump()
            # Skip chunks that carry no choices (empty or null list).
            if not raw_chunk["choices"]:
                continue
            choice = raw_chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            finish_reason = choice.get("finish_reason")
            generation_info = (
                dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
            # Later deltas reuse the class chosen for the previous chunk.
            default_chunk_class = chunk.__class__
            cg_chunk = ChatGenerationChunk(
                message=chunk,
                generation_info=generation_info,
            )
            if run_manager:
                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
            yield cg_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate``."""
        if self.streaming:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(self, messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        yuan2_creds: Dict[str, Any] = {
            "model": self.model_name,
        }
        return {**yuan2_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-yuan2"
def _create_retry_decorator(llm: ChatYuan2) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for completion calls.

    Args:
        llm: Model instance whose ``max_retries`` caps the number of attempts.

    Returns:
        A decorator that retries the wrapped callable on the openai error
        types listed below, re-raising the last error once attempts run out.
    """
    import openai

    retryable_errors = (
        openai.APITimeoutError,
        openai.APIError,
        openai.APIConnectionError,
        openai.RateLimitError,
        openai.InternalServerError,
    )
    shortest_wait = 1
    longest_wait = 60
    # Exponential back-off between attempts, clamped to the
    # [shortest_wait, longest_wait] window (in seconds).
    backoff = wait_exponential(multiplier=1, min=shortest_wait, max=longest_wait)
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=backoff,
        retry=retry_if_exception_type(retryable_errors),
        # Log at WARNING level before each sleep.
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
async def acompletion_with_retry(llm: ChatYuan2, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call.

    Args:
        llm: Model instance supplying ``async_client`` and the retry policy.
        **kwargs: Forwarded unchanged to ``llm.async_client.create``.

    Returns:
        Whatever ``llm.async_client.create`` returns.
    """
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.async_client.create(**kwargs)

    return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    """Convert a streamed delta mapping into a LangChain message chunk.

    Args:
        _dict: The delta; its ``role`` and ``content`` keys are read.
        default_class: Chunk class to fall back on; it also selects the
            branch whenever it equals one of the known chunk classes.

    Returns:
        A message chunk carrying the delta's content ("" when missing/empty).
    """
    role = _dict.get("role")
    content = _dict.get("content") or ""

    # Checked in this exact order: a branch matches on either the role name
    # or the default class.
    known_kinds = (
        ("user", HumanMessageChunk),
        ("assistant", AIMessageChunk),
        ("system", SystemMessageChunk),
    )
    for role_name, chunk_cls in known_kinds:
        if role == role_name or default_class == chunk_cls:
            return chunk_cls(content=content)

    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a response message mapping into a LangChain message.

    Args:
        _dict: The message; its ``role`` and ``content`` keys are read.

    Returns:
        A ``HumanMessage``, ``AIMessage`` or ``SystemMessage`` for the roles
        "user", "assistant" and "system"; a ``ChatMessage`` with the given
        role otherwise.
    """
    role = _dict.get("role")
    # `content` can be present but null; `.get("content", "")` would then
    # return None, so normalise it the same way the delta converter does.
    content = _dict.get("content") or ""
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        return AIMessage(content=content)
    elif role == "system":
        return SystemMessage(content=content)
    else:
        return ChatMessage(content=content, role=role)  # type: ignore[arg-type]
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary, with ``role`` and ``content`` keys, plus ``name`` for
        function messages or when ``additional_kwargs`` carries a ``name``.

    Raises:
        ValueError: If the message is of an unsupported type.
    """
    # Fixed roles, tested in this order after the ChatMessage check.
    fixed_roles = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (FunctionMessage, "function"),
    )

    carries_name = False
    if isinstance(message, ChatMessage):
        role = message.role
    else:
        for message_type, role_name in fixed_roles:
            if isinstance(message, message_type):
                role = role_name
                carries_name = message_type is FunctionMessage
                break
        else:
            raise ValueError(f"Got unknown type {message}")

    message_dict: Dict[str, Any] = {"role": role}
    if carries_name:
        # Function messages place their name between role and content.
        message_dict["name"] = message.name
    message_dict["content"] = message.content

    extras = message.additional_kwargs
    if "name" in extras:
        message_dict["name"] = extras["name"]
    return message_dict