18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class Accumulate(BaseSynthesizer):
    """Accumulate responses from multiple text chunks.

    The same QA prompt is applied independently to every text chunk, and the
    individual answers are joined into one numbered response string.
    Streaming is not supported in this mode.
    """

    def __init__(
        self,
        llm: Optional[LLMPredictorType] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[Any] = None,
        streaming: bool = False,
        use_async: bool = False,
        # deprecated
        service_context: Optional[ServiceContext] = None,
    ) -> None:
        """Create the synthesizer.

        Args:
            llm: LLM used for prediction (forwarded to the base class).
            callback_manager: Forwarded to the base class.
            prompt_helper: Forwarded to the base class; used to repack chunks.
            text_qa_template: QA prompt; defaults to
                ``DEFAULT_TEXT_QA_PROMPT_SEL`` when not given (or falsy).
            output_cls: If set, each chunk is answered via structured
                prediction into this class instead of plain text prediction.
            streaming: Forwarded to the base class; if True, ``get_response``
                and ``aget_response`` raise ``ValueError``.
            use_async: If True, ``get_response`` issues async predictions and
                resolves them with ``run_async_tasks``.
            service_context: Deprecated; forwarded to the base class.
        """
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            service_context=service_context,
            streaming=streaming,
        )
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
        self._use_async = use_async
        self._output_cls = output_cls

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"text_qa_template": self._text_qa_template}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (only ``text_qa_template`` is recognized)."""
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]

    def flatten_list(self, md_array: List[List[Any]]) -> List[Any]:
        """Flatten a list of lists by one level, preserving order."""
        return [item for sublist in md_array for item in sublist]

    def _format_response(self, outputs: List[Any], separator: str) -> str:
        """Join per-chunk outputs into one numbered string.

        Falsy outputs (e.g. ``None`` or ``""``) are rendered as
        "Empty Response". Outputs may be plain strings or structured
        ``output_cls`` objects; each is formatted with ``str()``.

        Args:
            outputs: One output per repacked chunk, in order.
            separator: String placed between the numbered responses.

        Returns:
            ``"Response 1: ...<separator>Response 2: ..."``; empty string when
            ``outputs`` is empty.
        """
        # Structured outputs are not strings, so the element type is Any.
        responses: List[Any] = [output or "Empty Response" for output in outputs]
        return separator.join(
            f"Response {number}: {item}"
            for number, item in enumerate(responses, start=1)
        )

    def _give_responses_for_chunks(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        use_async: bool,
        **response_kwargs: Any,
    ) -> List[Any]:
        """Call ``_give_responses`` for every chunk and flatten the results."""
        per_chunk = [
            self._give_responses(
                query_str, text_chunk, use_async=use_async, **response_kwargs
            )
            for text_chunk in text_chunks
        ]
        return self.flatten_list(per_chunk)

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        separator: str = "\n---------------------\n",
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Asynchronously apply the same prompt to each text chunk.

        Async predictors are always used here, regardless of ``use_async``;
        all predictions are awaited together with ``asyncio.gather``.

        Args:
            query_str: Query substituted into the QA template.
            text_chunks: Chunks to answer independently.
            separator: String placed between the numbered responses.
            **response_kwargs: Extra prompt arguments forwarded to the predictor.

        Returns:
            The numbered responses joined by ``separator``.

        Raises:
            ValueError: If the synthesizer was configured for streaming.
        """
        if self._streaming:
            raise ValueError("Unable to stream in Accumulate response mode")

        tasks = self._give_responses_for_chunks(
            query_str, text_chunks, use_async=True, **response_kwargs
        )
        outputs = await asyncio.gather(*tasks)
        return self._format_response(outputs, separator)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        separator: str = "\n---------------------\n",
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Apply the same prompt to each text chunk and return the responses.

        Args:
            query_str: Query substituted into the QA template.
            text_chunks: Chunks to answer independently.
            separator: String placed between the numbered responses.
            **response_kwargs: Extra prompt arguments forwarded to the predictor.

        Returns:
            The numbered responses joined by ``separator``.

        Raises:
            ValueError: If the synthesizer was configured for streaming.
        """
        if self._streaming:
            raise ValueError("Unable to stream in Accumulate response mode")

        outputs = self._give_responses_for_chunks(
            query_str, text_chunks, use_async=self._use_async, **response_kwargs
        )
        if self._use_async:
            # With use_async the list holds awaitables; resolve them here.
            outputs = run_async_tasks(outputs)
        return self._format_response(outputs, separator)

    def _give_responses(
        self,
        query_str: str,
        text_chunk: str,
        use_async: bool = False,
        **response_kwargs: Any,
    ) -> List[Any]:
        """Give responses given a query and a corresponding text chunk.

        The QA template is partially formatted with ``query_str``, and the
        chunk is repacked by the prompt helper against that template. One
        predictor call is made per piece returned by ``repack``.

        Args:
            query_str: Query substituted into the QA template.
            text_chunk: The single text chunk to answer over.
            use_async: If True, use the async predictor variants; the returned
                entries are then awaitables that the caller must resolve.
            **response_kwargs: Extra prompt arguments forwarded to the predictor.

        Returns:
            One entry per repacked piece: the prediction result, or its
            awaitable when ``use_async`` is True.
        """
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)
        text_chunks = self._prompt_helper.repack(text_qa_template, [text_chunk])

        # Both predictor families take the same trailing arguments; structured
        # prediction additionally needs the output class as its first argument.
        predictor: Callable
        leading_args: tuple
        if self._output_cls is None:
            predictor = self._llm.apredict if use_async else self._llm.predict
            leading_args = (text_qa_template,)
        else:
            predictor = (
                self._llm.astructured_predict
                if use_async
                else self._llm.structured_predict
            )
            leading_args = (self._output_cls, text_qa_template)

        return [
            predictor(*leading_args, context_str=cur_text_chunk, **response_kwargs)
            for cur_text_chunk in text_chunks
        ]