Skip to content

Corrective rag

CorrectiveRAGPack #

Bases: BaseLlamaPack

Source code in llama_index/packs/corrective_rag/base.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class CorrectiveRAGPack(BaseLlamaPack):
    """Corrective RAG pack.

    Retrieves nodes for a query, uses an LLM to grade each node's
    relevancy, and — when any node is judged irrelevant — rewrites the
    query and supplements the context with Tavily web-search results
    before synthesizing the final answer.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Init params.

        Args:
            documents: Documents to index for retrieval.
            tavily_ai_apikey: API key for the Tavily search tool.
        """
        llm = OpenAI(model="gpt-4")
        # Pipeline that grades a (context, query) pair as relevant or not.
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        # Pipeline that rewrites the query for better web-search recall.
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)

    def get_modules(self) -> Dict[str, Any]:
        """Get modules (the LLM and the vector index)."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve the relevant nodes for the query.

        kwargs are forwarded to ``index.as_retriever``
        (e.g. ``similarity_top_k``).
        """
        retriever = self.index.as_retriever(**kwargs)
        return retriever.retrieve(query_str)

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Evaluate relevancy of each retrieved node with the query.

        Returns one normalized (lowercased, stripped) verdict string per
        node — expected to be "yes" or "no" — in the same order as
        ``retrieved_nodes``.
        """
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            # Normalize the raw completion so it compares cleanly to "yes"/"no".
            relevancy_results.append(relevancy.message.content.lower().strip())
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Extract and newline-join the texts of nodes graded "yes"."""
        relevant_texts = [
            node.text
            for node, result in zip(retrieved_nodes, relevancy_results)
            if result == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Search the transformed query with the Tavily API."""
        search_results = self.tavily_tool.search(query_str, max_results=5)
        return "\n".join(result.text for result in search_results)

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Answer the query over the combined relevant and search text."""
        # Skip empty parts so an absent search_text does not leave a
        # stray trailing newline in the synthesized document.
        combined = "\n".join(part for part in (relevant_text, search_text) if part)
        index = SummaryIndex.from_documents([Document(text=combined)])
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the corrective-RAG pipeline; kwargs go to the retriever."""
        # Retrieve nodes based on the input query string.
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)

        # Grade each retrieved node's relevancy to the query.
        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
        # Keep only the text of nodes graded relevant.
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Fallback: if any node was graded irrelevant, rewrite the query
        # and supplement the context with web-search results.
        search_text = ""
        if "no" in relevancy_results:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # Synthesize the final answer; search_text is "" when no
        # web search was needed, so a single call covers both cases
        # (the original if/else called get_result identically in each arm).
        return self.get_result(relevant_text, search_text, query_str)

get_modules #

get_modules() -> Dict[str, Any]

Get modules (the LLM and the vector index).

Source code in llama_index/packs/corrective_rag/base.py
62
63
64
def get_modules(self) -> Dict[str, Any]:
    """Get modules (the LLM and the vector index)."""
    return {"llm": self.llm, "index": self.index}

retrieve_nodes #

retrieve_nodes(
    query_str: str, **kwargs: Any
) -> List[NodeWithScore]

Retrieve the relevant nodes for the query.

Source code in llama_index/packs/corrective_rag/base.py
66
67
68
69
def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
    """Retrieve the relevant nodes for the query.

    kwargs are forwarded to ``index.as_retriever`` (e.g. ``similarity_top_k``).
    """
    retriever = self.index.as_retriever(**kwargs)
    return retriever.retrieve(query_str)

evaluate_relevancy #

evaluate_relevancy(
    retrieved_nodes: List[Document], query_str: str
) -> List[str]

Evaluate the relevancy of each retrieved document with the query.

Source code in llama_index/packs/corrective_rag/base.py
71
72
73
74
75
76
77
78
79
80
81
def evaluate_relevancy(
    self, retrieved_nodes: List[NodeWithScore], query_str: str
) -> List[str]:
    """Evaluate relevancy of each retrieved node with the query.

    Returns one normalized (lowercased, stripped) verdict string per
    node — expected to be "yes" or "no" — in the same order as
    ``retrieved_nodes``.
    """
    relevancy_results = []
    for node in retrieved_nodes:
        relevancy = self.relevancy_pipeline.run(
            context_str=node.text, query_str=query_str
        )
        # Normalize the raw completion so it compares cleanly to "yes"/"no".
        relevancy_results.append(relevancy.message.content.lower().strip())
    return relevancy_results

extract_relevant_texts #

extract_relevant_texts(
    retrieved_nodes: List[NodeWithScore],
    relevancy_results: List[str],
) -> str

Extract the relevant texts from the retrieved documents.

Source code in llama_index/packs/corrective_rag/base.py
83
84
85
86
87
88
89
90
91
92
def extract_relevant_texts(
    self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
) -> str:
    """Extract and newline-join the texts of nodes graded "yes"."""
    relevant_texts = [
        retrieved_nodes[i].text
        for i, result in enumerate(relevancy_results)
        if result == "yes"
    ]
    return "\n".join(relevant_texts)

search_with_transformed_query #

search_with_transformed_query(query_str: str) -> str

Search the transformed query with the Tavily API.

Source code in llama_index/packs/corrective_rag/base.py
94
95
96
97
def search_with_transformed_query(self, query_str: str) -> str:
    """Search the transformed query with the Tavily API; join result texts."""
    search_results = self.tavily_tool.search(query_str, max_results=5)
    return "\n".join([result.text for result in search_results])

get_result #

get_result(
    relevant_text: str, search_text: str, query_str: str
) -> Any

Get the query result over the relevant text.

Source code in llama_index/packs/corrective_rag/base.py
 99
100
101
102
103
104
def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
    """Answer the query over the combined relevant and search text."""
    documents = [Document(text=relevant_text + "\n" + search_text)]
    index = SummaryIndex.from_documents(documents)
    query_engine = index.as_query_engine()
    return query_engine.query(query_str)

run #

run(query_str: str, **kwargs: Any) -> Any

Run the corrective-RAG pipeline.

Source code in llama_index/packs/corrective_rag/base.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def run(self, query_str: str, **kwargs: Any) -> Any:
    """Run the corrective-RAG pipeline; kwargs go to the retriever."""
    # Retrieve nodes based on the input query string.
    retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)

    # Evaluate the relevancy of each retrieved document in relation to the query string.
    relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
    # Extract texts from documents that are deemed relevant based on the evaluation.
    relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

    # Initialize search_text variable to handle cases where it might not get defined.
    search_text = ""

    # If any document is found irrelevant, transform the query string for better search results.
    if "no" in relevancy_results:
        transformed_query_str = self.transform_query_pipeline.run(
            query_str=query_str
        ).message.content
        # Conduct a search with the transformed query string and collect the results.
        search_text = self.search_with_transformed_query(transformed_query_str)

    # Compile the final result. If there's additional search text from the transformed query,
    # it's included; otherwise, only the relevant text from the initial retrieval is returned.
    if search_text:
        return self.get_result(relevant_text, search_text, query_str)
    else:
        return self.get_result(relevant_text, "", query_str)