class PostgresMLRetriever(BaseRetriever):
    """Retriever that runs vector searches against a PostgresML index.

    Args:
        index (PostgresMLIndex): The PostgresML index whose ``collection`` and
            ``pipeline`` are used to execute searches.
        callback_manager (Optional[CallbackManager]): Forwarded to
            ``BaseRetriever.__init__``.
    """

    def __init__(
        self,
        index: PostgresMLIndex,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Store the index and initialise the base retriever.

        Extra ``kwargs`` are accepted for interface compatibility but are
        not used here.
        """
        self._index = index
        super().__init__(callback_manager)

    def _retrieve(
        self,
        query_bundle: Optional[QueryBundle] = None,
        query: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = 5,
        **kwargs: Any,
    ) -> List[NodeWithScore]:
        """Synchronously retrieve the ``limit`` most similar nodes.

        Either ``query`` or ``query_bundle`` must be provided. If ``query``
        is provided it is used verbatim and ``limit`` is ignored.

        Args:
            query_bundle: Query whose ``query_str`` is searched for.
            query: A complete, pre-built ``vector_search`` query dict.
            limit: Maximum number of results when searching by
                ``query_bundle``.

        Returns:
            The retrieved nodes with their similarity scores.

        Raises:
            ValueError: If neither ``query`` nor ``query_bundle`` is given.
        """
        # Drive the async implementation to completion and unwrap the single
        # task's result.
        return run_async_tasks([self._aretrieve(query_bundle, query, limit, **kwargs)])[
            0
        ]

    async def _aretrieve(
        self,
        query_bundle: Optional[QueryBundle] = None,
        query: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = 5,
        **kwargs: Any,
    ) -> List[NodeWithScore]:
        """Asynchronously retrieve the ``limit`` most similar nodes.

        Either ``query`` or ``query_bundle`` must be provided. If ``query``
        is provided it is passed to ``vector_search`` unchanged and
        ``limit`` is ignored.

        Args:
            query_bundle: Query whose ``query_str`` is searched for.
            query: A complete, pre-built ``vector_search`` query dict.
            limit: Maximum number of results when searching by
                ``query_bundle``.

        Returns:
            One ``NodeWithScore`` per search result, wrapping a ``TextNode``
            built from the result's document id, chunk text and metadata.

        Raises:
            ValueError: If neither ``query`` nor ``query_bundle`` is given.
                (A subclass of ``Exception``, so broad handlers still catch it.)
        """
        if query:
            # Caller supplied the full search request; use it as-is.
            search_request = query
        else:
            if not query_bundle:
                raise ValueError(
                    "Must provide either query or query_bundle to retrieve and aretrieve"
                )
            search_request = {
                "query": {
                    "fields": {
                        "content": {
                            "query": query_bundle.query_str,
                            # NOTE(review): presumably the instruction prefix the
                            # pipeline's embedding model expects — confirm.
                            "parameters": {"prompt": "query: "},
                        }
                    }
                },
                "limit": limit,
            }

        results = await self._index.collection.vector_search(
            search_request,
            self._index.pipeline,
        )
        return [
            NodeWithScore(
                node=TextNode(
                    id_=r["document"]["id"],
                    text=r["chunk"],
                    metadata=r["document"]["metadata"],
                ),
                score=r["score"],
            )
            for r in results
        ]