Searchain

SearChainPack

Bases: BaseLlamaPack

Simple short form SearChain pack.
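
Example usage (a minimal sketch: the paths below are hypothetical placeholders, the DPR reader and cross-encoder checkpoints must already exist locally, and an OpenAI API key is assumed to be configured for the internal LLM):

from llama_index.packs.searchain.base import SearChainPack

# Point these at your own corpus and model checkpoints (placeholder paths).
pack = SearChainPack(
    data_path="data",
    dprtokenizer_path="./dpr_reader_multi",
    dprmodel_path="./dpr_reader_multi",
    crossencoder_name_or_path="./Quora_cross_encoder",
    device="cuda",
)

# Each line of the input file is a JSON object with a "question" field.
pack.execute(data_path="./questions.jsonl", start_idx=0)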

Source code in llama_index/packs/searchain/base.py
class SearChainPack(BaseLlamaPack):
    """简单的短形式SearChain包。"""

    def __init__(
        self,
        data_path: str,
        dprtokenizer_path: str = "/dpr_reader_multi",
        dprmodel_path: str = "/dpr_reader_multi",
        crossencoder_name_or_path: str = "/Quora_cross_encoder",
        device: str = "cuda",
        **kwargs: Any,
    ) -> None:
        """初始化参数。"""
        self.device = device
        self.crossencoder = CrossEncoder(crossencoder_name_or_path, device=self.device)
        self.documents = SimpleDirectoryReader(data_path).load_data()
        self.index = VectorStoreIndex.from_documents(self.documents)
        self.query_engine = self.index.as_query_engine()
        self.llm = OpenAI()

        self.dprtokenizer = DPRReaderTokenizer.from_pretrained(dprtokenizer_path)
        self.dprmodel = DPRReader.from_pretrained(dprmodel_path)
        self.dprmodel.eval()
        self.dprmodel.to(self.device)

    def _get_answer(self, query, texts, title):
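        # Extractive QA with the DPR reader: encode (question, title, text)
        # jointly, decode the argmax start/end span as the answer, and return
        # the reader's passage-relevance score alongside it.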
        print("texts:" + texts)
        encoded_inputs = self.dprtokenizer(
            questions=[query],
            titles=[title],
            texts=[texts],
            return_tensors="pt",
            max_length=510,
        )
        outputs = self.dprmodel(**encoded_inputs.to(self.device))
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits
        relevance_logits = outputs.relevance_logits

        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = encoded_inputs.input_ids[
            0, answer_start_index : answer_end_index + 1
        ]
        answer = self.dprtokenizer.decode(predict_answer_tokens)
        return answer, relevance_logits

    def _ir(self, query, query_seen_list):
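        # Scan the LLM output line by line for [Query]/[Unsolved Query]
        # entries, retrieve evidence for queries not seen before, and build a
        # correction message whenever the verified answer disagrees.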
        flag_continue_label = False
        query_list = query.split("\n")
        message = ""
        for idx in range(len(query_list)):
            query_item = query_list[idx]
            if "Query" in query_item and "]:" in query_item:
                temp = query_item.split("]")
                if len(temp) < 2:
                    continue
                query_type = temp[0]
                query_item = temp[1]
                if ":" in query_item:
                    query_item = query_item[1:]
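                # _have_seen_or_not is a module-level helper defined earlier
                # in base.py (not shown in this excerpt); it uses the
                # cross-encoder to skip queries equivalent to ones already asked.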
                if not _have_seen_or_not(
                    self.crossencoder, query_item, query_seen_list, query_type
                ):
                    now_reference = {}
                    query_seen_list.append(query_item)
                    response = str(self.query_engine.query(query_item))

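                    # Note: the retrieved response is passed to the reader as
                    # the DPR "title" field with an empty passage string.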
                    answer, relevance_score = self._get_answer(
                        query=query_item, texts="", title=response
                    )
                    now_reference["query"] = query_item
                    now_reference["answer"] = answer
                    now_reference["reference"] = response
                    now_reference["ref_score"] = relevance_score
                    now_reference["idx"] = response

                    if "Unsolved" in query_type:
                        message = "[Unsolved Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
                            query_item, answer, response
                        )
                        flag_continue_label = True
                        break
                    elif relevance_score > 1.5:
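                        # The reader is confident in the evidence: compare the
                        # LLM's answer for this step against the extracted one.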
                        answer_start_idx = idx + 1
                        predict_answer = ""
                        while answer_start_idx < len(query_list):
                            if "Answer" in query_list[answer_start_idx]:
                                predict_answer = query_list[answer_start_idx]
                                break
                            answer_start_idx += 1
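                        # _match_or_not is a module-level helper defined
                        # earlier in base.py (not shown in this excerpt); it
                        # checks agreement between the two answers.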
                        match_label = _match_or_not(
                            prediction=predict_answer, ground_truth=answer
                        )
                        if match_label:
                            continue
                        else:
                            message = "[Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
                                query_item, answer, response
                            )
                            flag_continue_label = True
                            break
        return message, flag_continue_label, query_seen_list

    def _extract(self, message_keys_list):
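        # Walk the chat history backwards and return the most recent
        # assistant message that contains a [Final Content] line.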
        text = message_keys_list
        idx = len(text)
        while idx > 0:
            idx = idx - 1
            item = text[idx]
            if item.role == "assistant" and "Final Content" in item.content:
                list_item = item.content.split("\n")
                for sp in list_item:
                    if "Final Content" in sp:
                        return item.content
        return "Sorry, I still cannot solve this question!"

    def execute(self, data_path, start_idx):
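        # Iterate over a JSONL file (one JSON object with a "question" field
        # per line), resuming from start_idx.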
        data = open(data_path)
        for k, example in enumerate(data):
            if k < start_idx:
                continue
            example = json.loads(example)
            q = example["question"]
            round_count = 0
            message_keys_list = [
                ChatMessage(
                    role="user",
                    content="""Construct a global reasoning chain for this complex [Question] : " {} " You should generate a query to the search engine based on what you already know at each step of the reasoning chain, starting with [Query]. If you know the answer for [Query], generate it starting with [Answer]. You can try to generate the final answer for the [Question] by referring to the [Query]-[Answer] pairs, starting with [Final Content]. If you don't know the answer, generate a query to search engine based on what you already know and do not know, starting with [Unsolved Query].
                        For example:
                        [Question]: "Where do greyhound buses that are in the birthplace of Spirit If...'s performer leave from? "
                        [Query 1]: Who is the performer of Spirit If... ?
                        If you don't know the answer:
                        [Unsolved Query]: Who is the performer of Spirit If... ?
                        If you know the answer:
                        [Answer 1]: The performer of Spirit If... is Kevin Drew.
                        [Query 2]: Where was Kevin Drew born?
                        If you don't know the answer:
                        [Unsolved Query]: Where was Kevin Drew born?
                        If you know the answer:
                        [Answer 2]: Toronto.
                        [Query 3]: Where do greyhound buses in Toronto leave from?
                        If you don't know the answer:
                        [Unsolved Query]: Where do greyhound buses in Toronto leave from?
                        If you know the answer:
                        [Answer 3]: Toronto Coach Terminal.
                        [Final Content]: The performer of Spirit If... is Kevin Drew [1]. Kevin Drew was born in Toronto [2]. Greyhound buses in Toronto leave from Toronto Coach Terminal [3]. So the final answer is Toronto Coach Terminal.

                        [Question]:"Which magazine was started first Arthur’s Magazine or First for Women?"
                        [Query 1]: When was Arthur’s Magazine started?
                        [Answer 1]: 1844.
                        [Query 2]: When was First for Women started?
                        [Answer 2]: 1989.
                        [Final Content]: Arthur’s Magazine started in 1844 [1]. First for Women started in 1989 [2]. So Arthur’s Magazine was started first. So the answer is Arthur’s Magazine.
                        [Question]: {}""".format(
                        q, q
                    ),
                )
            ]
            feedback_answer = "continue"
            predict_answer = ""
            query_seen_list = []
            while round_count < 5 and feedback_answer != "end":
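                # Each round: ask the LLM to extend the reasoning chain, then
                # verify it against retrieval and feed corrections back.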
                time.sleep(0.5)
                rsp = self.llm.chat(message_keys_list)
                round_count += 1
                input_str = str(rsp.message.content)
                message_keys_list.append(
                    ChatMessage(role="assistant", content=input_str)
                )
                predict_answer += input_str

                message, flag_continue_label, query_seen_list = self._ir(
                    input_str, query_seen_list
                )
                if flag_continue_label:
                    feedback_answer = message
                else:
                    feedback_answer = "end"

                if feedback_answer == "end":
                    break
                # [Query]:xxxx<SEP>[Answer]:xxxx<SEP>[Reference]:xxxx<SEP>
                feedback_list = feedback_answer.split("<SEP>")
                if "Unsolved Query" not in feedback_answer:
                    new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can change your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{}""".format(
                        feedback_list[2], feedback_list[0], feedback_list[1], q
                    )
                else:
                    new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can give your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{} """.format(
                        feedback_list[2], feedback_list[0], feedback_list[1], q
                    )
                message_keys_list.append(ChatMessage(role="user", content=new_prompt))
            result = self._extract(message_keys_list)
            print(result)

        data.close()
        return -1
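
The execute input is a JSONL file: each line is a JSON object with a "question" field, e.g. (an illustrative line, not a file shipped with the pack):

{"question": "Which magazine was started first Arthur’s Magazine or First for Women?"}

Each round, the raw LLM output is scanned by _ir; when the DPR reader disagrees with the model, a correction of the form

[Query]:...<SEP>[Answer]:...<SEP>[Reference]:...<SEP>

is fed back as the next user message, until the chain ends with a [Final Content] line or the round limit is reached.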