Skip to content

VoyageAI Rerank

VoyageAIRerank

Bases: BaseNodePostprocessor

Source code in llama_index/postprocessor/voyageai_rerank/base.py (lines 16–99):
class VoyageAIRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes using the VoyageAI rerank endpoint.

    The texts of the incoming nodes are sent, together with the query string,
    to ``voyageai.Client.rerank``; the nodes are then returned in the order
    given by the service, each carrying the relevance score it reported.
    """

    model: str = Field(description="Name of the model to use.")
    # Optional: ``__init__`` accepts ``None`` (meaning "return all documents"),
    # so the field must admit ``None`` as well or validation fails.
    top_n: Optional[int] = Field(
        default=None,
        description="The number of most relevant documents to return. If not specified, the reranking results of all documents will be returned.",
    )
    # Optional for the same reason: ``__init__`` defaults this to ``None``.
    truncation: Optional[bool] = Field(
        default=None,
        description="Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.",
    )

    # VoyageAI client instance; created in ``__init__``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        api_key: str,
        model: str,
        top_n: Optional[int] = None,
        truncation: Optional[bool] = None,
        # deprecated
        top_k: Optional[int] = None,
    ):
        """Create the reranker and its VoyageAI client.

        Args:
            api_key: API key handed to ``voyageai.Client``.
            model: Name of the rerank model to use.
            top_n: Number of most relevant documents to return. ``None``
                leaves the limit unset.
            truncation: Whether over-long inputs should be truncated. When
                ``None`` the argument is not forwarded, so the client's own
                default applies.
            top_k: Deprecated alias for ``top_n``; only used when ``top_n``
                is falsy.

        Raises:
            ImportError: If the ``voyageai`` package is not installed.
        """
        try:
            from voyageai import Client
        except ImportError as err:
            raise ImportError(
                "Cannot import voyageai package, please `pip install voyageai`."
            ) from err

        top_n = top_n or top_k
        # The base initialiser must run first: Pydantic only sets up storage
        # for private attributes during ``__init__``, so assigning
        # ``self._client`` earlier fails under Pydantic v2.
        super().__init__(top_n=top_n, model=model, truncation=truncation)
        self._client = Client(api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string ``"VoyageAIRerank"``."""
        return "VoyageAIRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query via the VoyageAI client.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the reranker.

        Returns:
            New ``NodeWithScore`` objects in the order returned by the
            service, scored with the reported ``relevance_score``. An empty
            list is returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        dispatch_event = dispatcher.get_dispatch_event()
        dispatch_event(
            ReRankStartEvent(
                query=query_bundle, nodes=nodes, top_n=self.top_n, model_name=self.model
            )
        )

        # NOTE: both early exits below happen after the start event has been
        # dispatched and before any end event is emitted.
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]

            # Forward ``truncation`` only when explicitly configured so that
            # an unset value never reaches the client as ``None``.
            rerank_kwargs = {}
            if self.truncation is not None:
                rerank_kwargs["truncation"] = self.truncation

            results = self._client.rerank(
                model=self.model,
                top_k=self.top_n,
                query=query_bundle.query_str,
                documents=texts,
                **rerank_kwargs,
            ).results

            # ``result.index`` refers back to the position in ``texts``, which
            # is aligned one-to-one with ``nodes``.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                for result in results
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatch_event(ReRankEndEvent(nodes=new_nodes))
        return new_nodes