Source code for langchain.chains.transform
"""Chain that runs an arbitrary Python function."""
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
# Module-scoped logger, named after this module so its records can be
# filtered/configured independently of other LangChain modules.
logger: logging.Logger = logging.getLogger(name=__name__)
class TransformChain(Chain):
    """Chain that transforms the chain output.

    The transformation is a user-supplied callable mapping the input dict to an
    output dict. An optional coroutine may be supplied for the async path; when
    it is absent, ``_acall`` falls back to the synchronous transform.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            def func(inputs: dict) -> dict:
                return {"entities": inputs["text"].split()}

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    transform_cb: Callable[[Dict[str, Any]], Dict[str, Any]] = Field(
        alias="transform"
    )
    """The transform function."""
    atransform_cb: Optional[
        Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
    ] = Field(None, alias="atransform")
    """The async coroutine transform function."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log ``msg`` as a warning only once.

        The function is wrapped in ``functools.lru_cache``, so a repeated call
        with the same ``msg`` is answered from the cache and does not reach
        ``logger.warning`` again (as long as the entry has not been evicted).

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Expect input keys.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output keys.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Apply the synchronous transform to ``inputs`` and return its result.

        ``run_manager`` is accepted but not used by this implementation.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Apply the async transform if provided, else the synchronous one.

        ``run_manager`` is accepted but not used by this implementation.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)
        else:
            self._log_once(
                "TransformChain's atransform is not provided, falling"
                " back to synchronous transform"
            )
            # NOTE: the sync transform is called directly (not in an executor),
            # so it runs on, and blocks, the event-loop thread while it executes.
            return self.transform_cb(inputs)