Skip to content

Retry

RetryGuidelineQueryEngine #

Bases: BaseQueryEngine

在查询引擎评估失败时,使用评估器反馈进行重试。

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query_engine` | `BaseQueryEngine` | 查询引擎对象 | required |
| `guideline_evaluator` | `GuidelineEvaluator` | 指南评估器对象 | required |
| `resynthesize_query` | `bool` | 是否重新合成查询 | `False` |
| `max_retries` | `int` | 最大重试次数 | `3` |
| `callback_manager` | `Optional[CallbackManager]` | 回调管理器对象 | `None` |
Source code in llama_index/core/query_engine/retry_query_engine.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class RetryGuidelineQueryEngine(BaseQueryEngine):
    """Retry a query, guided by evaluator feedback, when evaluation fails.

    Args:
        query_engine (BaseQueryEngine): Engine that produces the response.
        guideline_evaluator (GuidelineEvaluator): Evaluator whose ``passing``
            verdict decides whether the response is returned or retried.
        resynthesize_query (bool): Forwarded to the default
            ``FeedbackQueryTransformation`` built when no transformer is given.
        max_retries (int): Number of retries still allowed. When ``<= 0`` the
            response is returned without being evaluated.
        callback_manager (Optional[CallbackManager]): Passed to the base class.
        query_transformer (Optional[FeedbackQueryTransformation]): Rewrites the
            query from the failed evaluation. Defaults to
            ``FeedbackQueryTransformation(resynthesize_query=...)``.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        guideline_evaluator: GuidelineEvaluator,
        resynthesize_query: bool = False,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        query_transformer: Optional[FeedbackQueryTransformation] = None,
    ) -> None:
        self._query_engine = query_engine
        self._guideline_evaluator = guideline_evaluator
        self.max_retries = max_retries
        self.resynthesize_query = resynthesize_query
        self.query_transformer = query_transformer or FeedbackQueryTransformation(
            resynthesize_query=self.resynthesize_query
        )
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return the sub-modules exposed for prompt inspection."""
        # NOTE: the "guideline_evalator" spelling is kept on purpose; it is a
        # runtime key and renaming it would change what callers look up.
        return {
            "query_engine": self._query_engine,
            "guideline_evalator": self._guideline_evaluator,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with a transformed query on failure.

        The wrapped engine is queried once. If retries remain and the
        evaluator rejects the response, the query is rewritten from the
        evaluation and handed to a new engine with one fewer retry.
        """
        response = self._query_engine._query(query_bundle)
        if self.max_retries <= 0:
            return response

        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._guideline_evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Hand the configured transformer to the retry engine; otherwise a
        # caller-supplied transformer would be replaced by the default one
        # from the second retry onwards.
        new_query_engine = RetryGuidelineQueryEngine(
            self._query_engine,
            self._guideline_evaluator,
            self.resynthesize_query,
            self.max_retries - 1,
            self.callback_manager,
            query_transformer=self.query_transformer,
        )
        new_query = self.query_transformer.run(
            query_bundle, {"evaluation": evaluation}
        )
        logger.debug("New query: %s", new_query.query_str)
        return new_query_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point; has no native async path and calls ``_query``."""
        return self._query(query_bundle)

RetryQueryEngine #

Bases: BaseQueryEngine

如果评估失败,则在查询引擎上重试。

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query_engine` | `BaseQueryEngine` | 查询引擎对象 | required |
| `evaluator` | `BaseEvaluator` | 评估器对象 | required |
| `max_retries` | `int` | 最大重试次数 | `3` |
| `callback_manager` | `Optional[CallbackManager]` | 回调管理器对象 | `None` |
Source code in llama_index/core/query_engine/retry_query_engine.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class RetryQueryEngine(BaseQueryEngine):
    """Retry on the wrapped query engine when evaluation fails.

    Args:
        query_engine (BaseQueryEngine): Engine that produces the response.
        evaluator (BaseEvaluator): Evaluator whose ``passing`` verdict decides
            whether the response is returned or retried.
        max_retries (int): Number of retries still allowed. When ``<= 0`` the
            response is returned without being evaluated.
        callback_manager (Optional[CallbackManager]): Passed to the base class.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        evaluator: BaseEvaluator,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._query_engine = query_engine
        self._evaluator = evaluator
        self.max_retries = max_retries
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return the sub-modules exposed for prompt inspection."""
        return {"query_engine": self._query_engine, "evaluator": self._evaluator}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying with a transformed query on failure.

        The wrapped engine is queried once. If retries remain and the
        evaluator rejects the response, the query is rewritten by a
        ``FeedbackQueryTransformation`` and handed to a new engine with one
        fewer retry.
        """
        response = self._query_engine._query(query_bundle)
        if self.max_retries <= 0:
            return response

        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        evaluation = self._evaluator.evaluate_response(
            query_bundle.query_str, typed_response
        )
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Pass the callback manager along so retries use the same one as the
        # first attempt, matching RetryGuidelineQueryEngine.
        new_query_engine = RetryQueryEngine(
            self._query_engine,
            self._evaluator,
            self.max_retries - 1,
            self.callback_manager,
        )
        query_transformer = FeedbackQueryTransformation()
        new_query = query_transformer.run(query_bundle, {"evaluation": evaluation})
        return new_query_engine.query(new_query)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point; has no native async path and calls ``_query``."""
        return self._query(query_bundle)

RetrySourceQueryEngine #

Bases: BaseQueryEngine

使用不同的源节点进行重试。

Source code in llama_index/core/query_engine/retry_source_query_engine.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class RetrySourceQueryEngine(BaseQueryEngine):
    """Retry a failed query using only the source nodes that pass evaluation.

    Args:
        query_engine (RetrieverQueryEngine): Engine that produces the response.
        evaluator (BaseEvaluator): Evaluator applied to the response and then
            to each source node individually.
        llm (Optional[LLM]): LLM to store; resolved from ``Settings`` or
            ``service_context`` when omitted.
        max_retries (int): Number of retries still allowed. When ``<= 0`` the
            response is returned without being evaluated.
        callback_manager (Optional[CallbackManager]): Callback manager;
            resolved from ``Settings`` or ``service_context`` when omitted.
        service_context (Optional[ServiceContext]): Deprecated fallback source
            for the LLM and callback manager.
    """

    def __init__(
        self,
        query_engine: RetrieverQueryEngine,
        evaluator: BaseEvaluator,
        llm: Optional[LLM] = None,
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        # deprecated
        service_context: Optional[ServiceContext] = None,
    ) -> None:
        """Store the wrapped engine, evaluator, LLM and retry budget."""
        self._query_engine = query_engine
        self._evaluator = evaluator
        self._llm = llm or llm_from_settings_or_context(Settings, service_context)
        self.max_retries = max_retries
        super().__init__(
            callback_manager=callback_manager
            or callback_manager_from_settings_or_context(Settings, service_context)
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return the sub-modules exposed for prompt inspection."""
        return {"query_engine": self._query_engine, "evaluator": self._evaluator}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, retrying over the passing source nodes on failure.

        The wrapped engine is queried once. If retries remain and the
        evaluator rejects the response, each source node is evaluated on its
        own, the passing nodes are indexed into a ``SummaryIndex``, and the
        same query is run against that index with one fewer retry.

        Raises:
            ValueError: If no source node passes evaluation.
        """
        response = self._query_engine._query(query_bundle)
        if self.max_retries <= 0:
            return response

        typed_response = (
            response if isinstance(response, Response) else response.get_response()
        )
        query_str = query_bundle.query_str
        evaluation = self._evaluator.evaluate_response(query_str, typed_response)
        if evaluation.passing:
            logger.debug("Evaluation returned True.")
            return response

        logger.debug("Evaluation returned False.")
        # Evaluate each source node alone to find the ones worth keeping.
        orig_nodes = typed_response.source_nodes
        source_evals = [
            self._evaluator.evaluate(
                query=query_str,
                response=typed_response.response,
                contexts=[source_node.get_content()],
            )
            for source_node in orig_nodes
        ]
        # Filter on the evaluator's verdict. Testing the result object itself
        # is not a pass/fail check, so nothing would ever be filtered out.
        new_docs = [
            Document(text=node.node.get_content())
            for node, eval_result in zip(orig_nodes, source_evals)
            if eval_result.passing
        ]
        if not new_docs:
            raise ValueError("No source nodes passed evaluation.")

        new_index = SummaryIndex.from_documents(
            new_docs,
        )
        new_retriever_engine = RetrieverQueryEngine(new_index.as_retriever())
        new_query_engine = RetrySourceQueryEngine(
            new_retriever_engine,
            self._evaluator,
            self._llm,
            self.max_retries - 1,
            # Retries use the same callback manager as the first attempt.
            callback_manager=self.callback_manager,
        )
        return new_query_engine.query(query_bundle)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async entry point; has no native async path and calls ``_query``."""
        return self._query(query_bundle)