Skip to content

Multi step

MultiStepQueryEngine #

Bases: BaseQueryEngine

多步查询引擎。

该查询引擎在现有的基础查询引擎之上运行,并借助多步查询转换将问题逐步分解为子问题。

Parameters:

Name Type Description Default
query_engine BaseQueryEngine

一个BaseQueryEngine对象。

required
query_transform StepDecomposeQueryTransform

一个StepDecomposeQueryTransform对象。

required
response_synthesizer Optional[BaseSynthesizer]

一个BaseSynthesizer对象。

None
num_steps Optional[int]

运行多步查询的最大步骤数;为None时不限制步数(此时early_stopping必须为True)。

3
early_stopping bool

是否在停止函数返回True时提前停止。

True
index_summary str

索引的字符串摘要。

'None'
stop_fn Optional[Callable[[Dict], bool]]

接受信息字典并返回布尔值的停止函数。

None
Source code in llama_index/core/query_engine/multistep_query_engine.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class MultiStepQueryEngine(BaseQueryEngine):
    """Multi-step query engine.

    Runs on top of an existing base query engine together with a multi-step
    query transform. On each step the transform turns the original question
    (plus the reasoning gathered so far and the index summary) into the next
    sub-question, and the wrapped engine answers it. When the loop ends, every
    sub-question/answer pair is handed to a response synthesizer, which builds
    the final answer.

    Args:
        query_engine (BaseQueryEngine): Engine used to answer each sub-question.
        query_transform (StepDecomposeQueryTransform): Transform that produces
            the next sub-question.
        response_synthesizer (Optional[BaseSynthesizer]): Synthesizer for the
            final answer. Defaults to ``get_response_synthesizer`` bound to the
            wrapped engine's callback manager.
        num_steps (Optional[int]): Maximum number of sub-questions to run.
            ``None`` means no step limit, which is only allowed when
            ``early_stopping`` is True.
        early_stopping (bool): If True, the loop ends as soon as ``stop_fn``
            returns True for the next sub-question. If False, ``stop_fn`` is
            not consulted and the loop runs exactly ``num_steps`` steps.
        index_summary (str): String summary of the index, passed to the
            query transform.
        stop_fn (Optional[Callable[[Dict], bool]]): Stop predicate. It is
            called with ``{"query_bundle": <next sub-question QueryBundle>}``.
            Defaults to ``default_stop_fn``.

    Raises:
        ValueError: If ``early_stopping`` is False and ``num_steps`` is None.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: StepDecomposeQueryTransform,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        num_steps: Optional[int] = 3,
        early_stopping: bool = True,
        index_summary: str = "None",
        stop_fn: Optional[Callable[[Dict], bool]] = None,
    ) -> None:
        self._query_engine = query_engine
        self._query_transform = query_transform
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=self._query_engine.callback_manager
        )

        self._index_summary = index_summary
        self._num_steps = num_steps
        self._early_stopping = early_stopping
        # TODO: make interface to stop function better
        self._stop_fn = stop_fn or default_stop_fn
        # Without early stopping, the step limit is the only thing that ends
        # the loop, so it must be provided.
        if not self._early_stopping and self._num_steps is None:
            raise ValueError("Must specify num_steps if early_stopping is False.")

        callback_manager = self._query_engine.callback_manager
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return the prompt sub-modules (synthesizer and query transform)."""
        return {
            "response_synthesizer": self._response_synthesizer,
            "query_transform": self._query_transform,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Run the multi-step loop, then synthesize the final response."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)

            final_response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            # Expose the intermediate sub-question/answer pairs to callers.
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``.

        Sub-questions are answered with the wrapped engine's ``aquery``, so the
        event loop is not blocked while the steps run.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = await self._aquery_multistep(query_bundle)

            final_response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            # Expose the intermediate sub-question/answer pairs to callers.
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    def _combine_queries(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> QueryBundle:
        """Build the next sub-question from the original query and prior reasoning."""
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": self._index_summary,
        }
        return self._query_transform(query_bundle, metadata=transform_metadata)

    def _step_limit_reached(self, cur_steps: int) -> bool:
        """Return True once ``num_steps`` sub-questions have been answered."""
        return self._num_steps is not None and cur_steps >= self._num_steps

    def _stop_requested(self, query_bundle: QueryBundle) -> bool:
        """Return True if early stopping is enabled and ``stop_fn`` fires."""
        if not self._early_stopping:
            # Early stopping disabled: only the step limit ends the loop.
            return False
        # TODO: make stop logic better
        return bool(self._stop_fn({"query_bundle": query_bundle}))

    @staticmethod
    def _record_step(
        query_bundle: QueryBundle,
        response: RESPONSE_TYPE,
        text_chunks: List[str],
        source_nodes: List[NodeWithScore],
        metadata: Dict[str, Any],
    ) -> str:
        """Store one sub-question/answer pair.

        The pair is appended to ``text_chunks``, ``source_nodes`` and
        ``metadata["sub_qa"]`` in place.

        Returns:
            The reasoning text to append to the running ``prev_reasoning``.
        """
        query_str = query_bundle.query_str
        answer = str(response)
        text_chunks.append(f"\nQuestion: {query_str}\n" f"Answer: {answer}")
        source_nodes.extend(response.source_nodes)
        metadata["sub_qa"].append((query_str, response))
        return f"- {query_str}\n" f"- {answer}\n"

    def _query_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run the multi-step loop synchronously.

        Returns:
            A tuple ``(nodes, source_nodes, metadata)``:
            - ``nodes``: one text node per sub-question/answer pair.
            - ``source_nodes``: all source nodes returned by the sub-queries.
            - ``metadata``: a dict with a ``"sub_qa"`` list of
              ``(sub_question, response)`` tuples.
        """
        prev_reasoning = ""
        cur_steps = 0
        final_response_metadata: Dict[str, Any] = {"sub_qa": []}
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []

        # Check the step limit before building the next sub-question, so no
        # transform call is wasted once the limit is hit.
        while not self._step_limit_reached(cur_steps):
            updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)
            if self._stop_requested(updated_query_bundle):
                break

            cur_response = self._query_engine.query(updated_query_bundle)
            prev_reasoning += self._record_step(
                updated_query_bundle,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        nodes = [NodeWithScore(node=TextNode(text=chunk)) for chunk in text_chunks]
        return nodes, source_nodes, final_response_metadata

    async def _aquery_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Async counterpart of ``_query_multistep``.

        Each sub-question is answered with the wrapped engine's ``aquery``.
        The return value has the same shape as ``_query_multistep``.
        """
        prev_reasoning = ""
        cur_steps = 0
        final_response_metadata: Dict[str, Any] = {"sub_qa": []}
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []

        while not self._step_limit_reached(cur_steps):
            updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)
            if self._stop_requested(updated_query_bundle):
                break

            cur_response = await self._query_engine.aquery(updated_query_bundle)
            prev_reasoning += self._record_step(
                updated_query_bundle,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        nodes = [NodeWithScore(node=TextNode(text=chunk)) for chunk in text_chunks]
        return nodes, source_nodes, final_response_metadata