class TransformQueryEngine(BaseQueryEngine):
    """Transform query engine.

    Runs a query transform on the query bundle before handing it to the
    wrapped query engine.

    Args:
        query_engine (BaseQueryEngine): The wrapped query engine that receives
            the transformed query.
        query_transform (BaseQueryTransform): The transform applied to every
            incoming query bundle.
        transform_metadata (Optional[dict]): Metadata passed to the query
            transform's ``run`` call as ``metadata=``.
        callback_manager (Optional[CallbackManager]): Callback manager,
            forwarded to the base class constructor.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: BaseQueryTransform,
        transform_metadata: Optional[dict] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._query_engine = query_engine
        self._query_transform = query_transform
        self._transform_metadata = transform_metadata
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return the prompt sub-modules: the query transform and the wrapped engine."""
        return {
            "query_transform": self._query_transform,
            "query_engine": self._query_engine,
        }

    def _transform_query(self, query_bundle: QueryBundle) -> QueryBundle:
        """Apply the configured query transform, with its metadata, to ``query_bundle``.

        Every public entry point routes through this helper, so the
        transform/metadata wiring is defined in exactly one place.
        """
        return self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Transform the query, then retrieve nodes with the wrapped engine.

        Args:
            query_bundle: The original (untransformed) query.

        Returns:
            The nodes returned by the wrapped engine's ``retrieve``.
        """
        return self._query_engine.retrieve(self._transform_query(query_bundle))

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Transform the query, then synthesize a response with the wrapped engine.

        Args:
            query_bundle: The original (untransformed) query. The transform is
                applied here, so callers should not pre-transform it.
            nodes: Nodes to synthesize the response from.
            additional_source_nodes: Optional extra source nodes, forwarded
                unchanged.

        Returns:
            The wrapped engine's ``synthesize`` result.
        """
        return self._query_engine.synthesize(
            query_bundle=self._transform_query(query_bundle),
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of :meth:`synthesize`.

        Args:
            query_bundle: The original (untransformed) query.
            nodes: Nodes to synthesize the response from.
            additional_source_nodes: Optional extra source nodes, forwarded
                unchanged.

        Returns:
            The awaited result of the wrapped engine's ``asynthesize``.
        """
        # NOTE: the transform itself runs synchronously; only the wrapped
        # engine call is awaited.
        return await self._query_engine.asynthesize(
            query_bundle=self._transform_query(query_bundle),
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query by transforming it and delegating to the wrapped engine."""
        return self._query_engine.query(self._transform_query(query_bundle))

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Asynchronously answer a query via the wrapped engine's ``aquery``."""
        # NOTE: the transform itself runs synchronously; only the wrapped
        # engine call is awaited.
        return await self._query_engine.aquery(self._transform_query(query_bundle))