# Source code for langchain_core.utils.utils
"""Generic utility functions."""
import contextlib
import datetime
import functools
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from packaging.version import parse
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate that the specified keyword arguments are mutually exclusive.

    The decorated function must be called with exactly one non-None keyword
    argument from each group in ``arg_groups``; otherwise ``ValueError`` is
    raised, naming every group that failed. Only keyword arguments are
    inspected -- values passed positionally are not counted.

    Args:
        *arg_groups: Groups of keyword-argument names.

    Returns:
        A decorator that applies the check before calling the wrapped function.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Exactly one argument in each group must be non-None.
            offending = []
            for group in arg_groups:
                provided = [name for name in group if kwargs.get(name) is not None]
                if len(provided) != 1:
                    offending.append(", ".join(group))
            if offending:
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(offending)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
def raise_for_status_with_text(response: Response) -> None:
    """Raise an error carrying the response body when the HTTP status is an error.

    Args:
        response: The ``requests`` response to check.

    Raises:
        ValueError: If ``response.raise_for_status()`` raises ``HTTPError``.
            The message is ``response.text``, and the original ``HTTPError``
            is chained as the cause.
    """
    try:
        response.raise_for_status()
    except HTTPError as http_err:
        failure = ValueError(response.text)
        raise failure from http_err
@contextlib.contextmanager
def mock_now(dt_value):  # type: ignore
    """Context manager for mocking out ``datetime.datetime.now()`` in unit tests.

    Example:
        with mock_now(datetime.datetime(2011, 2, 3, 10, 11)):
            assert datetime.datetime.now() == datetime.datetime(2011, 2, 3, 10, 11)

    While the context is active, the ``datetime`` module's ``datetime``
    attribute is replaced by a subclass whose ``now()`` returns ``dt_value``.
    The original class is restored on exit, even if the body raises. Only
    lookups that go through the module attribute at call time (e.g.
    ``datetime.datetime.now()``) see the mock; a name bound earlier via
    ``from datetime import datetime`` still refers to the real class.

    Args:
        dt_value: The datetime that ``now()`` should report.

    Yields:
        The mocked ``datetime.datetime`` class.
    """

    class MockDateTime(datetime.datetime):
        """datetime.datetime subclass whose now() returns a fixed value."""

        @classmethod
        def now(cls, tz=None):  # type: ignore
            # Build a copy of dt_value (tzinfo is kept, fold is not). While the
            # mock is active, `datetime.datetime` resolves to this subclass.
            value = datetime.datetime(
                dt_value.year,
                dt_value.month,
                dt_value.day,
                dt_value.hour,
                dt_value.minute,
                dt_value.second,
                dt_value.microsecond,
                dt_value.tzinfo,
            )
            if tz is not None:
                # Mirror the real datetime.now(tz): report the same instant
                # expressed in `tz`. A naive value is treated as local time,
                # which is how astimezone() interprets naive datetimes.
                value = value.astimezone(tz)
            return value

    real_datetime = datetime.datetime
    datetime.datetime = MockDateTime
    try:
        yield datetime.datetime
    finally:
        datetime.datetime = real_datetime
def guard_import(
    module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
    """Dynamically import a module, raising a helpful error if it is missing.

    Args:
        module_name: Name of the module to import. It may be relative (leading
            ``.``) when ``package`` is given.
        pip_name: Name of the pip package to suggest in the error message.
            If empty, the top-level package name is used, with ``_`` replaced
            by ``-``.
        package: Anchor package for resolving a relative ``module_name``;
            passed straight through to ``importlib.import_module``.

    Returns:
        The imported module.

    Raises:
        ImportError: If the module cannot be imported. The message suggests a
            ``pip install`` command, and the original error is chained as
            ``__cause__``.
    """
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as e:  # also covers ModuleNotFoundError (a subclass)
        if not pip_name:
            # A relative name such as ".sub" has no top-level package of its
            # own; derive the suggestion from the anchor package instead.
            base = package if module_name.startswith(".") and package else module_name
            pip_name = base.split(".")[0].replace("_", "-")
        raise ImportError(
            f"Could not import {module_name} python package. "
            f"Please install it with `pip install {pip_name}`."
        ) from e
    return module
def check_package_version(
    package: str,
    lt_version: Optional[str] = None,
    lte_version: Optional[str] = None,
    gt_version: Optional[str] = None,
    gte_version: Optional[str] = None,
) -> None:
    """Check that an installed package's version satisfies the given bounds.

    Bounds left as ``None`` are skipped. The bounds are checked in the order
    lt, lte, gt, gte, and the first violated one raises.

    Args:
        package: Distribution name passed to ``importlib.metadata.version``.
        lt_version: The installed version must be strictly below this.
        lte_version: The installed version must be at most this.
        gt_version: The installed version must be strictly above this.
        gte_version: The installed version must be at least this.

    Raises:
        ValueError: If the installed version violates any provided bound.
        importlib.metadata.PackageNotFoundError: If ``package`` is not installed.
    """
    installed = parse(version(package))
    # Each entry: (bound, symbol used in the message, predicate detecting a violation).
    constraints = (
        (lt_version, "<", lambda have, limit: have >= limit),
        (lte_version, "<=", lambda have, limit: have > limit),
        (gt_version, ">", lambda have, limit: have <= limit),
        (gte_version, ">=", lambda have, limit: have < limit),
    )
    for bound, symbol, violates in constraints:
        if bound is None:
            continue
        if violates(installed, parse(bound)):
            raise ValueError(
                f"Expected {package} version to be {symbol} {bound}. Received "
                f"{installed}."
            )
def get_pydantic_field_names(pydantic_cls: Any) -> Set[str]:
    """Get the field names of a pydantic class, including aliases.

    Reads the class's ``__fields__`` mapping. Each field contributes its
    ``name``, and also its ``alias`` when ``has_alias`` is true.

    Args:
        pydantic_cls: The pydantic class.

    Returns:
        The set of field names and aliases.
    """
    model_fields = list(pydantic_cls.__fields__.values())
    plain_names = {model_field.name for model_field in model_fields}
    alias_names = {
        model_field.alias for model_field in model_fields if model_field.has_alias
    }
    return plain_names | alias_names
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
    """Convert a string to a SecretStr if needed.

    Args:
        value: A plain string or an existing SecretStr.

    Returns:
        ``value`` itself if it is already a SecretStr, otherwise
        ``SecretStr(value)``.
    """
    return value if isinstance(value, SecretStr) else SecretStr(value)