"""实现了RunnablePassthrough。"""
from __future__ import annotations
import asyncio
import inspect
import threading
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
Union,
cast,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Other,
Runnable,
RunnableParallel,
RunnableSerializable,
)
from langchain_core.runnables.config import (
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
get_executor_for_config,
patch_config,
)
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
AddableDict,
ConfigurableFieldSpec,
create_model,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
def identity(x: Other) -> Other:
    """Return the argument unchanged (the identity function).

    Args:
        x: Any value.

    Returns:
        The very same object that was passed in.
    """
    result = x
    return result
async def aidentity(x: Other) -> Other:
    """Asynchronous identity function: resolve to the argument unchanged.

    Args:
        x: Any value.

    Returns:
        The very same object that was passed in, once awaited.
    """
    passthrough_value = x
    return passthrough_value
class RunnablePassthrough(RunnableSerializable[Other, Other]):
    """Runnable that passes its input through unchanged, or adds extra keys.

    This runnable behaves almost like the identity function, except that it
    can be configured (via :meth:`assign`) to add extra keys to the output when
    the input is a dict. An optional side-effect callable ``func`` / ``afunc``
    may be supplied; it is called with the input and its return value is
    ignored.

    The examples below show how this runnable works in a few simple chains.
    The chains rely on simple lambdas so they are easy to run and experiment
    with.

    Examples:

        .. code-block:: python

            from langchain_core.runnables import (
                RunnableLambda,
                RunnableParallel,
                RunnablePassthrough,
            )

            runnable = RunnableParallel(
                origin=RunnablePassthrough(),
                modified=lambda x: x+1
            )

            runnable.invoke(1) # {'origin': 1, 'modified': 2}


            def fake_llm(prompt: str) -> str: # Fake LLM for the example
                return "completion"

            chain = RunnableLambda(fake_llm) | {
                'original': RunnablePassthrough(), # Original LLM output
                'parsed': lambda text: text[::-1] # Parsing logic
            }

            chain.invoke('hello') # {'original': 'completion', 'parsed': 'noitelpmoc'}

    In some cases it is useful to pass the input through while also adding
    some keys to the output. For that, use the ``assign`` method:

        .. code-block:: python

            from langchain_core.runnables import RunnablePassthrough

            def fake_llm(prompt: str) -> str: # Fake LLM for the example
                return "completion"

            runnable = {
                'llm1':  fake_llm,
                'llm2':  fake_llm,
            } | RunnablePassthrough.assign(
                total_chars=lambda inputs: len(inputs['llm1'] + inputs['llm2'])
            )

            runnable.invoke('hello')
            # {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
    """

    # Optional declared input type; reported by both InputType and OutputType.
    input_type: Optional[Type[Other]] = None

    # Optional synchronous side-effect callback (return value ignored).
    func: Optional[
        Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]]
    ] = None

    # Optional asynchronous side-effect callback (return value ignored).
    afunc: Optional[
        Union[
            Callable[[Other], Awaitable[None]],
            Callable[[Other, RunnableConfig], Awaitable[None]],
        ]
    ] = None

    def __repr_args__(self) -> Any:
        # Without this repr(self) raises a RecursionError
        # See https://github.com/pydantic/pydantic/issues/7327
        return []

    def __init__(
        self,
        func: Optional[
            Union[
                Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]],
                Union[
                    Callable[[Other], Awaitable[None]],
                    Callable[[Other, RunnableConfig], Awaitable[None]],
                ],
            ]
        ] = None,
        afunc: Optional[
            Union[
                Callable[[Other], Awaitable[None]],
                Callable[[Other, RunnableConfig], Awaitable[None]],
            ]
        ] = None,
        *,
        input_type: Optional[Type[Other]] = None,
        **kwargs: Any,
    ) -> None:
        """Create a passthrough, optionally with a side-effect callback.

        Args:
            func: Callable invoked with the input (and optionally the config);
                its return value is ignored. If it is a coroutine function it
                is stored as ``afunc`` instead.
            afunc: Async callable awaited with the input (and optionally the
                config) on the async code paths; its return value is ignored.
            input_type: Optional input type, reported by ``InputType`` and
                ``OutputType``.
            **kwargs: Forwarded to the parent constructor.
        """
        if inspect.iscoroutinefunction(func):
            # A coroutine function passed as ``func`` is routed to ``afunc`` so
            # the async paths await it instead of calling it synchronously.
            afunc = func
            func = None

        super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)  # type: ignore[call-arg]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def InputType(self) -> Any:
        return self.input_type or Any

    @property
    def OutputType(self) -> Any:
        # Output equals input, so the output type is the input type.
        return self.input_type or Any

    @classmethod
    def assign(
        cls,
        **kwargs: Union[
            Runnable[Dict[str, Any], Any],
            Callable[[Dict[str, Any]], Any],
            Mapping[
                str,
                Union[Runnable[Dict[str, Any], Any], Callable[[Dict[str, Any]], Any]],
            ],
        ],
    ) -> "RunnableAssign":
        """Merge the Dict input with the output produced by the given steps.

        Args:
            **kwargs: Runnables, callables, or mappings of those, keyed by the
                output key each one populates. They are wrapped in a
                ``RunnableParallel`` whose output is merged into the input dict.

        Returns:
            A ``RunnableAssign`` that merges the Dict input with the output
            produced by the steps.
        """
        return RunnableAssign(RunnableParallel(kwargs))

    def invoke(
        self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Other:
        """Call ``func`` (if set) for its side effect, then return ``input``."""
        if self.func is not None:
            call_func_with_variable_args(
                self.func, input, ensure_config(config), **kwargs
            )
        return self._call_with_config(identity, input, config)

    async def ainvoke(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Other:
        """Await ``afunc`` if set, else call ``func`` if set; return ``input``."""
        if self.afunc is not None:
            await acall_func_with_variable_args(
                self.afunc, input, ensure_config(config), **kwargs
            )
        elif self.func is not None:
            call_func_with_variable_args(
                self.func, input, ensure_config(config), **kwargs
            )
        return await self._acall_with_config(aidentity, input, config)

    def transform(
        self,
        input: Iterator[Other],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Other]:
        """Yield input chunks unchanged, then call ``func`` on their aggregate.

        ``stream`` delegates here. Without this override the inherited
        ``Runnable.transform`` would be used, which (in langchain_core) buffers
        the input and calls ``stream`` again -- an endless recursion.
        """
        if self.func is None:
            yield from self._transform_stream_with_config(input, identity, config)
            return

        got_first_chunk = False
        final: Other
        for chunk in self._transform_stream_with_config(input, identity, config):
            yield chunk
            # ``func`` operates on the whole input, so aggregate the chunks.
            if not got_first_chunk:
                final = chunk
                got_first_chunk = True
            else:
                try:
                    final = final + chunk  # type: ignore[operator]
                except TypeError:
                    # Chunks that cannot be added: keep only the latest one.
                    final = chunk

        if got_first_chunk:
            call_func_with_variable_args(
                self.func, final, ensure_config(config), **kwargs
            )

    async def atransform(
        self,
        input: AsyncIterator[Other],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Other]:
        """Yield input chunks unchanged, then run ``afunc``/``func`` on them.

        ``astream`` delegates here; overriding avoids falling back to the
        inherited default, which would call back into ``astream``.
        """
        if self.afunc is None and self.func is None:
            async for chunk in self._atransform_stream_with_config(
                input, identity, config
            ):
                yield chunk
            return

        got_first_chunk = False
        final: Other
        async for chunk in self._atransform_stream_with_config(
            input, identity, config
        ):
            yield chunk
            # The callback operates on the whole input, so aggregate the chunks.
            if not got_first_chunk:
                final = chunk
                got_first_chunk = True
            else:
                try:
                    final = final + chunk  # type: ignore[operator]
                except TypeError:
                    # Chunks that cannot be added: keep only the latest one.
                    final = chunk

        if got_first_chunk:
            run_config = ensure_config(config)
            if self.afunc is not None:
                await acall_func_with_variable_args(
                    self.afunc, final, run_config, **kwargs
                )
            elif self.func is not None:
                call_func_with_variable_args(self.func, final, run_config, **kwargs)

    def stream(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Other]:
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Other,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Other]:
        async def input_aiter() -> AsyncIterator[Other]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk
# Shared RunnablePassthrough instance; RunnableAssign.get_graph() adds it as
# the node connecting the mapper graph's first node to its last node.
_graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
    """Runnable that assigns key-value pairs to Dict[str, Any] inputs.

    The ``RunnableAssign`` class takes an input dictionary, applies
    transformations to it through a ``RunnableParallel`` instance, and then
    combines the results with the original data, introducing new key-value
    pairs according to the mapper's logic.

    Examples:
        .. code-block:: python

            # This is a RunnableAssign
            from typing import Dict
            from langchain_core.runnables.passthrough import (
                RunnableAssign,
                RunnableParallel,
            )
            from langchain_core.runnables.base import RunnableLambda

            def add_ten(x: Dict[str, int]) -> Dict[str, int]:
                return {"added": x["input"] + 10}

            mapper = RunnableParallel(
                {"add_step": RunnableLambda(add_ten),}
            )

            runnable_assign = RunnableAssign(mapper)

            # Synchronous example
            runnable_assign.invoke({"input": 5})
            # returns {'input': 5, 'add_step': {'added': 15}}

            # Asynchronous example
            await runnable_assign.ainvoke({"input": 5})
            # returns {'input': 5, 'add_step': {'added': 15}}
    """

    # Parallel steps whose combined output is merged into the input dict.
    mapper: RunnableParallel[Dict[str, Any]]

    def __init__(self, mapper: RunnableParallel[Dict[str, Any]], **kwargs: Any) -> None:
        super().__init__(mapper=mapper, **kwargs)  # type: ignore[call-arg]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def get_name(
        self, suffix: Optional[str] = None, *, name: Optional[str] = None
    ) -> str:
        name = (
            name
            or self.name
            or f"RunnableAssign<{','.join(self.mapper.steps__.keys())}>"
        )
        return super().get_name(suffix, name=name)

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Build the output schema from the mapper's input and output schemas.

        When both schemas are dict-like, their fields are combined; mapper
        output fields come last so they override same-named input fields,
        mirroring ``{**input, **mapper_output}`` in ``_invoke``.
        """
        map_input_schema = self.mapper.get_input_schema(config)
        map_output_schema = self.mapper.get_output_schema(config)
        if (
            not map_input_schema.__custom_root_type__
            and not map_output_schema.__custom_root_type__
        ):
            # ie. both are dicts.
            # ``outer_type_`` keeps the full annotation (e.g. ``List[int]``);
            # pydantic v1's ``type_`` is only the inner element type.
            return create_model(  # type: ignore[call-overload]
                "RunnableAssignOutput",
                **{
                    k: (v.outer_type_, v.default)
                    for s in (map_input_schema, map_output_schema)
                    for k, v in s.__fields__.items()
                },
            )
        elif not map_output_schema.__custom_root_type__:
            # ie. only map output is a dict
            # ie. input type is either unknown or inferred incorrectly
            return map_output_schema

        return super().get_output_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return self.mapper.config_specs

    def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
        """Return the mapper's graph plus a passthrough node.

        The passthrough node is wired from the graph's first node to its last
        node, representing the original input that is merged into the output.
        """
        # get graph from mapper
        graph = self.mapper.get_graph(config)
        # add passthrough node and edges
        input_node = graph.first_node()
        output_node = graph.last_node()
        if input_node is not None and output_node is not None:
            passthrough_node = graph.add_node(_graph_passthrough)
            graph.add_edge(input_node, passthrough_node)
            graph.add_edge(passthrough_node, output_node)
        return graph

    def _invoke(
        self,
        input: Dict[str, Any],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Return ``input`` merged with the mapper's output (output wins)."""
        assert isinstance(
            input, dict
        ), "The input to RunnablePassthrough.assign() must be a dict."

        return {
            **input,
            **self.mapper.invoke(
                input,
                patch_config(config, callbacks=run_manager.get_child()),
                **kwargs,
            ),
        }

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Async variant of ``_invoke``."""
        assert isinstance(
            input, dict
        ), "The input to RunnablePassthrough.assign() must be a dict."

        return {
            **input,
            **await self.mapper.ainvoke(
                input,
                patch_config(config, callbacks=run_manager.get_child()),
                **kwargs,
            ),
        }

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _transform(
        self,
        input: Iterator[Dict[str, Any]],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream non-mapper input keys as they arrive, then the mapper output."""
        # collect mapper keys
        mapper_keys = set(self.mapper.steps__.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())

        # create map output stream
        map_output = self.mapper.transform(
            for_map,
            patch_config(
                config,
                callbacks=run_manager.get_child(),
            ),
            **kwargs,
        )

        # get executor to start map output stream in background
        with get_executor_for_config(config) as executor:
            # Pull the first map chunk in a worker so the mapper starts running
            # while the passthrough stream is consumed below.
            first_map_chunk_future = executor.submit(
                next,
                map_output,  # type: ignore
                None,
            )
            # consume passthrough stream
            for chunk in for_passthrough:
                assert isinstance(
                    chunk, dict
                ), "The input to RunnablePassthrough.assign() must be a dict."
                # remove mapper keys from passthrough chunk, to be overwritten by map
                filtered = AddableDict(
                    {k: v for k, v in chunk.items() if k not in mapper_keys}
                )
                if filtered:
                    yield filtered
            # yield map output
            first_map_chunk = first_map_chunk_future.result()
            # ``next(..., None)`` gives None when the mapper produced nothing
            # (e.g. an empty input stream); don't emit that as a chunk.
            if first_map_chunk is not None:
                yield cast(Dict[str, Any], first_map_chunk)
            for chunk in map_output:
                yield chunk

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream through ``_transform`` with callbacks/config handling.

        ``stream`` delegates here; without this override the inherited default
        (which, in langchain_core, calls ``stream`` again) would be used.
        """
        yield from self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    async def _atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        run_manager: AsyncCallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async variant of ``_transform``."""
        # collect mapper keys
        mapper_keys = set(self.mapper.steps__.keys())
        # create two streams, one for the map and one for the passthrough
        for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
        # create map output stream
        map_output = self.mapper.atransform(
            for_map,
            patch_config(
                config,
                callbacks=run_manager.get_child(),
            ),
            **kwargs,
        )
        # start map output stream
        first_map_chunk_task: asyncio.Task = asyncio.create_task(
            py_anext(map_output, None),  # type: ignore[arg-type]
        )
        # consume passthrough stream
        async for chunk in for_passthrough:
            assert isinstance(
                chunk, dict
            ), "The input to RunnablePassthrough.assign() must be a dict."
            # remove mapper keys from passthrough chunk, to be overwritten by map output
            filtered = AddableDict(
                {k: v for k, v in chunk.items() if k not in mapper_keys}
            )
            if filtered:
                yield filtered
        # yield map output
        first_map_chunk = await first_map_chunk_task
        # ``py_anext(..., None)`` gives None when the mapper produced nothing;
        # don't emit that as a chunk.
        if first_map_chunk is not None:
            yield first_map_chunk
        async for chunk in map_output:
            yield chunk

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async stream through ``_atransform`` with callbacks/config handling."""
        async for chunk in self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        ):
            yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk
class RunnablePick(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
    """Runnable that selectively picks keys from a dict input.

    You specify one or more keys to extract from the input dictionary. With a
    list of keys it returns a new dictionary containing only the selected keys
    that are present in the input. With a single string key it returns just
    that key's value.

    Example:
        .. code-block:: python

            from langchain_core.runnables.passthrough import RunnablePick

            input_data = {
                'name': 'John',
                'age': 30,
                'city': 'New York',
                'country': 'USA'
            }

            runnable = RunnablePick(keys=['name', 'age'])

            output_data = runnable.invoke(input_data)

            print(output_data)  # Output: {'name': 'John', 'age': 30}
    """

    # A single key (value returned directly) or a list of keys (dict returned).
    keys: Union[str, List[str]]

    def __init__(self, keys: Union[str, List[str]], **kwargs: Any) -> None:
        super().__init__(keys=keys, **kwargs)  # type: ignore[call-arg]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def get_name(
        self, suffix: Optional[str] = None, *, name: Optional[str] = None
    ) -> str:
        key_list = [self.keys] if isinstance(self.keys, str) else self.keys
        name = name or self.name or f"RunnablePick<{','.join(key_list)}>"
        return super().get_name(suffix, name=name)

    def _pick(self, input: Dict[str, Any]) -> Any:
        """Select the configured key(s) from ``input``.

        Returns:
            For a single string key, ``input.get(key)`` (None if missing).
            For a list of keys, an ``AddableDict`` of the keys present in
            ``input``, or None when none of them are present.
        """
        assert isinstance(input, dict), "The input to RunnablePick must be a dict."

        if isinstance(self.keys, str):
            return input.get(self.keys)

        picked = {k: input[k] for k in self.keys if k in input}
        if picked:
            return AddableDict(picked)
        return None

    def _invoke(
        self,
        input: Dict[str, Any],
    ) -> Dict[str, Any]:
        return self._pick(input)

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Dict[str, Any],
    ) -> Dict[str, Any]:
        return self._pick(input)

    async def ainvoke(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _transform(
        self,
        input: Iterator[Dict[str, Any]],
    ) -> Iterator[Dict[str, Any]]:
        """Pick from each chunk, skipping chunks where nothing was picked."""
        for chunk in input:
            picked = self._pick(chunk)
            if picked is not None:
                yield picked

    def transform(
        self,
        input: Iterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        """Stream through ``_transform`` with callbacks/config handling.

        ``stream`` delegates here; without this override the inherited default
        (which, in langchain_core, calls ``stream`` again) would be used.
        """
        yield from self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    async def _atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async variant of ``_transform``."""
        async for chunk in input:
            picked = self._pick(chunk)
            if picked is not None:
                yield picked

    async def atransform(
        self,
        input: AsyncIterator[Dict[str, Any]],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Async stream through ``_atransform`` with callbacks/config handling."""
        async for chunk in self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        ):
            yield chunk

    def stream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
        return self.transform(iter([input]), config, **kwargs)

    async def astream(
        self,
        input: Dict[str, Any],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        async def input_aiter() -> AsyncIterator[Dict[str, Any]]:
            yield input

        async for chunk in self.atransform(input_aiter(), config, **kwargs):
            yield chunk