Skip to content

Retrieval

评估模块。

BaseRetrievalEvaluator #

Bases: BaseModel

基本的检索评估器类。

Source code in llama_index/core/evaluation/retrieval/base.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class BaseRetrievalEvaluator(BaseModel):
    """Base retrieval evaluator class.

    Runs a configured list of retrieval metrics against a query and its
    expected (ground-truth) ids/texts.
    """

    # Metrics applied to every evaluated query.
    metrics: List[BaseRetrievalMetric] = Field(
        ..., description="List of metrics to evaluate"
    )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """Create evaluator from metric names.

        Args:
            metric_names (List[str]): List of metric names.
            **kwargs: Additional arguments for the evaluator.
        """
        metric_types = resolve_metrics(metric_names)
        return cls(metrics=[metric() for metric in metric_types], **kwargs)

    @abstractmethod
    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts. Subclasses must implement."""
        raise NotImplementedError

    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with a query string and expected ids (sync wrapper).

        Args:
            query (str): Query string.
            expected_ids (List[str]): Expected ids.

        Returns:
            RetrievalEvalResult: Evaluation result.
        """
        # Synchronous facade over the async implementation.
        return asyncio_run(
            self.aevaluate(
                query=query,
                expected_ids=expected_ids,
                expected_texts=expected_texts,
                mode=mode,
                **kwargs,
            )
        )

    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with the query string, retrieved contexts,
        and generated response string.

        Subclasses can override this method to provide custom evaluation logic
        and take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        metric_dict = {}
        for metric in self.metrics:
            eval_result = metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            metric_dict[metric.metric_name] = eval_result

        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )

    async def aevaluate_dataset(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        workers: int = 2,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[RetrievalEvalResult]:
        """Run evaluation with an entire dataset.

        Args:
            dataset (EmbeddingQAFinetuneDataset): Dataset of queries and relevant docs.
            workers (int): Maximum number of concurrent evaluations.
            show_progress (bool): Whether to show a tqdm progress bar.
            **kwargs: Extra arguments forwarded to ``aevaluate``.
        """
        # Bound concurrency so at most `workers` evaluations run at once.
        semaphore = asyncio.Semaphore(workers)

        async def eval_worker(
            query: str, expected_ids: List[str], mode: RetrievalEvalMode
        ) -> RetrievalEvalResult:
            async with semaphore:
                # BUGFIX: forward **kwargs so per-call options reach aevaluate
                # (they were previously accepted but silently dropped).
                return await self.aevaluate(
                    query, expected_ids=expected_ids, mode=mode, **kwargs
                )

        response_jobs = []
        mode = RetrievalEvalMode.from_str(dataset.mode)
        for query_id, query in dataset.queries.items():
            expected_ids = dataset.relevant_docs[query_id]
            response_jobs.append(eval_worker(query, expected_ids, mode))
        if show_progress:
            from tqdm.asyncio import tqdm_asyncio

            eval_results = await tqdm_asyncio.gather(*response_jobs)
        else:
            eval_results = await asyncio.gather(*response_jobs)

        return eval_results

from_metric_names classmethod #

from_metric_names(
    metric_names: List[str], **kwargs: Any
) -> BaseRetrievalEvaluator

从指标名称创建评估器。

Source code in llama_index/core/evaluation/retrieval/base.py
86
87
88
89
90
91
92
93
94
95
96
97
    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """Build an evaluator from a list of metric names.

        Args:
            metric_names (List[str]): List of metric names.
            **kwargs: Additional arguments for the evaluator.
        """
        # Resolve each name to its metric class, then instantiate all of them.
        resolved_types = resolve_metrics(metric_names)
        instantiated = [metric_cls() for metric_cls in resolved_types]
        return cls(metrics=instantiated, **kwargs)

evaluate #

evaluate(
    query: str,
    expected_ids: List[str],
    expected_texts: Optional[List[str]] = None,
    mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
    **kwargs: Any
) -> RetrievalEvalResult

运行带有查询字符串和预期id的评估结果。

Parameters:

Name Type Description Default
query str

查询字符串

required
expected_ids List[str]

预期的id列表

required

Returns:

Name Type Description
RetrievalEvalResult RetrievalEvalResult

评估结果

Source code in llama_index/core/evaluation/retrieval/base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with a query string and expected ids (sync wrapper).

        Args:
            query (str): Query string.
            expected_ids (List[str]): Expected ids.

        Returns:
            RetrievalEvalResult: Evaluation result.
        """
        # Build the coroutine, then drive it to completion synchronously.
        coro = self.aevaluate(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            mode=mode,
            **kwargs,
        )
        return asyncio_run(coro)

aevaluate async #

aevaluate(
    query: str,
    expected_ids: List[str],
    expected_texts: Optional[List[str]] = None,
    mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
    **kwargs: Any
) -> RetrievalEvalResult

使用查询字符串、检索到的上下文和生成的响应字符串运行评估。

子类可以重写此方法以提供自定义评估逻辑,并接受额外的参数。

Source code in llama_index/core/evaluation/retrieval/base.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with the query string, retrieved contexts,
        and generated response string.

        Subclasses can override this method to provide custom evaluation logic
        and take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        # Compute every configured metric, keyed by its metric name.
        metric_dict = {
            metric.metric_name: metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            for metric in self.metrics
        }

        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )

aevaluate_dataset async #

aevaluate_dataset(
    dataset: EmbeddingQAFinetuneDataset,
    workers: int = 2,
    show_progress: bool = False,
    **kwargs: Any
) -> List[RetrievalEvalResult]

使用数据集进行评估。

Source code in llama_index/core/evaluation/retrieval/base.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
async def aevaluate_dataset(
    self,
    dataset: EmbeddingQAFinetuneDataset,
    workers: int = 2,
    show_progress: bool = False,
    **kwargs: Any,
) -> List[RetrievalEvalResult]:
    """Run evaluation with an entire dataset.

    Args:
        dataset (EmbeddingQAFinetuneDataset): Dataset of queries and relevant docs.
        workers (int): Maximum number of concurrent evaluations.
        show_progress (bool): Whether to show a tqdm progress bar.
        **kwargs: Extra arguments forwarded to ``aevaluate``.
    """
    # Bound concurrency so at most `workers` evaluations run at once.
    semaphore = asyncio.Semaphore(workers)

    async def eval_worker(
        query: str, expected_ids: List[str], mode: RetrievalEvalMode
    ) -> RetrievalEvalResult:
        async with semaphore:
            # BUGFIX: forward **kwargs so per-call options reach aevaluate
            # (they were previously accepted but silently dropped).
            return await self.aevaluate(
                query, expected_ids=expected_ids, mode=mode, **kwargs
            )

    response_jobs = []
    mode = RetrievalEvalMode.from_str(dataset.mode)
    for query_id, query in dataset.queries.items():
        expected_ids = dataset.relevant_docs[query_id]
        response_jobs.append(eval_worker(query, expected_ids, mode))
    if show_progress:
        from tqdm.asyncio import tqdm_asyncio

        eval_results = await tqdm_asyncio.gather(*response_jobs)
    else:
        eval_results = await asyncio.gather(*response_jobs)

    return eval_results

RetrieverEvaluator #

Bases: BaseRetrievalEvaluator

检索器评估器。

该模块将使用一组指标评估检索器。

Parameters:

Name Type Description Default
metrics List[BaseRetrievalMetric]

用于评估的指标序列

required
retriever BaseRetriever

要评估的检索器。

required
node_postprocessors Optional[List[BaseNodePostprocessor]]

检索后应用的后处理器。

None
Source code in llama_index/core/evaluation/retrieval/evaluator.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Retriever evaluator.

    This module evaluates a retriever using a set of metrics.

    Args:
        metrics (List[BaseRetrievalMetric]): Sequence of metrics to evaluate with.
        retriever: The retriever under evaluation.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Post-processors applied to the nodes after retrieval.
    """

    # The retriever whose results are scored.
    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    # Optional chain of post-processors run over the retrieved nodes.
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts, applying any configured post-processors."""
        retrieved_nodes = await self.retriever.aretrieve(query)

        # Run each post-processor in order over the retrieved nodes.
        if self.node_postprocessors:
            for postprocessor in self.node_postprocessors:
                retrieved_nodes = postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        node_ids = [scored.node.node_id for scored in retrieved_nodes]
        node_texts = [scored.node.text for scored in retrieved_nodes]
        return node_ids, node_texts

RetrievalEvalResult #

Bases: BaseModel

检索评估结果。

注意:这个抽象可能会在将来发生变化。

属性:

query (str): 查询字符串
expected_ids (List[str]): 期望的id
retrieved_ids (List[str]): 检索到的id
metric_dict (Dict[str, BaseRetrievalMetric]): 评估的指标字典

Source code in llama_index/core/evaluation/retrieval/base.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class RetrievalEvalResult(BaseModel):
    """Retrieval evaluation result.

    NOTE: this abstraction might change in the future.

    Attributes:
        query (str): Query string
        expected_ids (List[str]): Expected ids
        retrieved_ids (List[str]): Retrieved ids
        metric_dict (Dict[str, BaseRetrievalMetric]): Metric dictionary for the evaluation
    """

    class Config:
        arbitrary_types_allowed = True

    query: str = Field(..., description="Query string")
    expected_ids: List[str] = Field(..., description="Expected ids")
    expected_texts: Optional[List[str]] = Field(
        default=None,
        description="Expected texts associated with nodes provided in `expected_ids`",
    )
    retrieved_ids: List[str] = Field(..., description="Retrieved ids")
    retrieved_texts: List[str] = Field(..., description="Retrieved texts")
    mode: "RetrievalEvalMode" = Field(
        default=RetrievalEvalMode.TEXT, description="text or image"
    )
    metric_dict: Dict[str, RetrievalMetricResult] = Field(
        ..., description="Metric dictionary for the evaluation"
    )

    @property
    def metric_vals_dict(self) -> Dict[str, float]:
        """Dictionary of metric scores keyed by metric name."""
        scores: Dict[str, float] = {}
        for name, result in self.metric_dict.items():
            scores[name] = result.score
        return scores

    def __str__(self) -> str:
        """Human-readable representation."""
        return f"Query: {self.query}\nMetrics: {self.metric_vals_dict!s}\n"

metric_vals_dict property #

metric_vals_dict: Dict[str, float]

度量值的字典。