class FLAREInstructQueryEngine(BaseQueryEngine):
    """FLARE Instruct query engine.

    This is the version of FLARE that uses retrieval-encouraging instructions.

    NOTE: this is a beta feature. Interfaces might change, and it might not
    always give correct answers.

    Args:
        query_engine (BaseQueryEngine): query engine to use
        llm (Optional[LLM]): LLM model. Defaults to None.
        instruct_prompt (Optional[PromptTemplate]): instruct prompt. Defaults to None.
        lookahead_answer_inserter (Optional[BaseLookaheadAnswerInserter]):
            lookahead answer inserter. Defaults to None.
        done_output_parser (Optional[IsDoneOutputParser]): done output parser.
            Defaults to None.
        query_task_output_parser (Optional[QueryTaskOutputParser]):
            query task output parser. Defaults to None.
        max_iterations (int): max iterations. Defaults to 10.
        max_lookahead_query_tasks (int): max lookahead query tasks. Defaults to 1.
        callback_manager (Optional[CallbackManager]): callback manager.
            Defaults to None.
        verbose (bool): give verbose outputs. Defaults to False.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        llm: Optional[LLM] = None,
        instruct_prompt: Optional[BasePromptTemplate] = None,
        lookahead_answer_inserter: Optional[BaseLookaheadAnswerInserter] = None,
        done_output_parser: Optional[IsDoneOutputParser] = None,
        query_task_output_parser: Optional[QueryTaskOutputParser] = None,
        max_iterations: int = 10,
        max_lookahead_query_tasks: int = 1,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        super().__init__(callback_manager=callback_manager)
        self._query_engine = query_engine
        # fall back to the globally configured LLM when none is supplied
        self._llm = llm or Settings.llm
        self._instruct_prompt = instruct_prompt or DEFAULT_INSTRUCT_PROMPT
        self._lookahead_answer_inserter = lookahead_answer_inserter or (
            LLMLookaheadAnswerInserter(llm=self._llm)
        )
        self._done_output_parser = done_output_parser or IsDoneOutputParser()
        self._query_task_output_parser = (
            query_task_output_parser or QueryTaskOutputParser()
        )
        self._max_iterations = max_iterations
        self._max_lookahead_query_tasks = max_lookahead_query_tasks
        self._verbose = verbose

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {
            "instruct_prompt": self._instruct_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "instruct_prompt" in prompts:
            self._instruct_prompt = prompts["instruct_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_engine": self._query_engine,
            "lookahead_answer_inserter": self._lookahead_answer_inserter,
        }

    def _get_relevant_lookahead_response(self, updated_lookahead_resp: str) -> str:
        """Get relevant lookahead response.

        Truncates the updated lookahead response at the start position of the
        first remaining ``[Search(query)]`` tag, if any.
        """
        # if there's remaining query tasks, then truncate the response
        # until the start position of the first tag
        # there may be remaining query tasks because the _max_lookahead_query_tasks
        # is less than the total number of generated [Search(query)] tags
        remaining_query_tasks = self._query_task_output_parser.parse(
            updated_lookahead_resp
        )
        if not remaining_query_tasks:
            relevant_lookahead_resp = updated_lookahead_resp
        else:
            first_task = remaining_query_tasks[0]
            relevant_lookahead_resp = updated_lookahead_resp[: first_task.start_idx]
        return relevant_lookahead_resp

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query and get response.

        Iteratively generates lookahead responses containing
        ``[Search(query)]`` tags, answers each tagged query with the wrapped
        query engine, and splices the answers back into the running response
        until the LLM signals completion or ``max_iterations`` is reached.
        """
        print_text(f"Query: {query_bundle.query_str}\n", color="green")
        cur_response = ""
        source_nodes = []
        # NOTE: loop variable renamed from `iter` — it shadowed the builtin
        # and was unused in the body.
        for _ in range(self._max_iterations):
            if self._verbose:
                print_text(f"Current response: {cur_response}\n", color="blue")
            # generate "lookahead response" that contains "[Search(query)]" tags
            # e.g.
            # The colors on the flag of Ghana have the following meanings. Red is
            # for [Search(Ghana flag meaning)],...
            lookahead_resp = self._llm.predict(
                self._instruct_prompt,
                query_str=query_bundle.query_str,
                existing_answer=cur_response,
            )
            lookahead_resp = lookahead_resp.strip()
            if self._verbose:
                print_text(f"Lookahead response: {lookahead_resp}\n", color="pink")

            is_done, fmt_lookahead = self._done_output_parser.parse(lookahead_resp)
            if is_done:
                cur_response = cur_response.strip() + " " + fmt_lookahead.strip()
                break

            # parse lookahead response into query tasks
            query_tasks = self._query_task_output_parser.parse(lookahead_resp)

            # get answers for each query task
            query_tasks = query_tasks[: self._max_lookahead_query_tasks]
            query_answers = []
            for query_task in query_tasks:
                answer_obj = self._query_engine.query(query_task.query_str)
                if not isinstance(answer_obj, Response):
                    raise ValueError(
                        f"Expected Response object, got {type(answer_obj)} instead."
                    )
                query_answer = str(answer_obj)
                query_answers.append(query_answer)
                source_nodes.extend(answer_obj.source_nodes)

            # fill in the lookahead response template with the query answers
            # from the query engine
            updated_lookahead_resp = self._lookahead_answer_inserter.insert(
                lookahead_resp, query_tasks, query_answers, prev_response=cur_response
            )

            # get "relevant" lookahead response by truncating the updated
            # lookahead response until the start position of the first tag
            # also remove the prefix from the lookahead response, so that
            # we can concatenate it with the existing response
            relevant_lookahead_resp_wo_prefix = self._get_relevant_lookahead_response(
                updated_lookahead_resp
            )

            if self._verbose:
                print_text(
                    "Updated lookahead response: "
                    + f"{relevant_lookahead_resp_wo_prefix}\n",
                    color="pink",
                )

            # append the relevant lookahead response to the final response
            cur_response = (
                cur_response.strip() + " " + relevant_lookahead_resp_wo_prefix.strip()
            )

        # NOTE: at the moment, does not support streaming
        return Response(response=cur_response, source_nodes=source_nodes)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async query entry point.

        NOTE(review): this runs the synchronous ``_query`` inline and will
        block the event loop for the duration of the FLARE loop; a true async
        path would need async LLM/query-engine calls.
        """
        return self._query(query_bundle)

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes via the underlying query engine's retriever.

        Raises:
            NotImplementedError: if the wrapped query engine is not a
                ``RetrieverQueryEngine``.
        """
        # if the query engine is a retriever, then use the retrieve method
        if isinstance(self._query_engine, RetrieverQueryEngine):
            return self._query_engine.retrieve(query_bundle)
        else:
            raise NotImplementedError(
                "This query engine does not support retrieve, use query directly"
            )

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes via the underlying query engine's retriever.

        Raises:
            NotImplementedError: if the wrapped query engine is not a
                ``RetrieverQueryEngine``.
        """
        # if the query engine is a retriever, then use the retrieve method
        if isinstance(self._query_engine, RetrieverQueryEngine):
            return await self._query_engine.aretrieve(query_bundle)
        else:
            raise NotImplementedError(
                "This query engine does not support retrieve, use query directly"
            )