class PostprocessorComponent(QueryComponent):
    """Query component that delegates to a node postprocessor.

    Takes a list of ``NodeWithScore`` objects under the ``nodes`` key (plus an
    optional ``query_str``) and returns the postprocessor's result under the
    ``nodes`` output key.
    """

    # The wrapped postprocessor that performs the actual work.
    postprocessor: BaseNodePostprocessor = Field(..., description="Postprocessor")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Assign ``callback_manager`` to the wrapped postprocessor."""
        self.postprocessor.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the component inputs at run time.

        Args:
            input: Mapping of input names to values. Must contain ``nodes``;
                may contain ``query_str``.

        Returns:
            The same ``input`` mapping; when ``query_str`` is present its value
            is replaced in place by the result of
            ``validate_and_convert_stringable``.

        Raises:
            ValueError: If ``nodes`` is missing, is not a list, or holds an
                element that is not a ``NodeWithScore``.
        """
        if "nodes" not in input:
            raise ValueError("Input must have key 'nodes'")

        candidate_nodes = input["nodes"]
        if not isinstance(candidate_nodes, list):
            raise ValueError("Input nodes must be a list")

        # Every element has to be a NodeWithScore.
        if not all(isinstance(item, NodeWithScore) for item in candidate_nodes):
            raise ValueError("Input nodes must be a list of NodeWithScore")

        # The query string is optional; normalise it only when supplied.
        if "query_str" in input:
            raw_query = input["query_str"]
            input["query_str"] = validate_and_convert_stringable(raw_query)

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run the postprocessor over the supplied nodes.

        Args:
            **kwargs: Must include ``nodes``; ``query_str`` is optional and is
                passed as ``None`` when absent.

        Returns:
            A dict with a single ``nodes`` key holding the postprocessor output.
        """
        postprocess = self.postprocessor.postprocess_nodes
        incoming_nodes = kwargs["nodes"]
        query = kwargs.get("query_str")
        processed = postprocess(incoming_nodes, query_str=query)
        return {"nodes": processed}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run the component (async entry point).

        NOTE: there is no native async path for the postprocessor here, so
        this simply calls the synchronous implementation.
        """
        result = self._run_component(**kwargs)
        return result

    @property
    def input_keys(self) -> InputKeys:
        """Input keys: ``nodes`` is required, ``query_str`` is optional."""
        required = {"nodes"}
        optional = {"query_str"}
        return InputKeys.from_keys(required, optional_keys=optional)

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys: a single ``nodes`` entry."""
        produced = {"nodes"}
        return OutputKeys.from_keys(produced)