15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class VoyageEmbedding(BaseEmbedding):
    """Embedding model backed by the Voyage AI embeddings API.

    Args:
        model_name (str): Name of the Voyage embedding model. This argument is
            required (there is no default). Passing the legacy "voyage-01"
            logs a warning recommending an explicit, newer model.
        voyage_api_key (Optional[str]): Voyage API key. Defaults to None.
            You can pass the key here or store it as an environment variable;
            when None, resolving the key is left to the voyageai client
            (presumably from the environment -- verify against voyageai docs).
        embed_batch_size (Optional[int]): Forwarded to ``BaseEmbedding``. When
            None, defaults to 72 for "voyage-2"/"voyage-02" and 7 otherwise.
        truncation (Optional[bool]): Forwarded unchanged as the ``truncation``
            argument of every ``embed`` request. Defaults to None.
        callback_manager (Optional[CallbackManager]): Forwarded to
            ``BaseEmbedding``.
        **kwargs: Extra keyword arguments forwarded to ``BaseEmbedding``.
    """

    # Synchronous and asynchronous Voyage clients, created in __init__.
    _client: voyageai.Client = PrivateAttr(None)
    _aclient: voyageai.client_async.AsyncClient = PrivateAttr()
    # Sent as the ``truncation`` argument of every ``embed`` request.
    truncation: Optional[bool] = None

    def __init__(
        self,
        model_name: str,
        voyage_api_key: Optional[str] = None,
        embed_batch_size: Optional[int] = None,
        truncation: Optional[bool] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        if model_name == "voyage-01":
            logger.warning(
                "voyage-01 is not the latest model by Voyage AI. Please note that `model_name` "
                "will be a required argument in the future. We recommend setting it explicitly. Please see "
                "https://docs.voyageai.com/docs/embeddings for the latest models offered by Voyage AI."
            )

        # Model-dependent default batch size: the voyage-2 family gets a
        # larger batch than other models.
        if embed_batch_size is None:
            embed_batch_size = 72 if model_name in ["voyage-2", "voyage-02"] else 7

        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )

        self._client = voyageai.Client(api_key=voyage_api_key)
        self._aclient = voyageai.AsyncClient(api_key=voyage_api_key)
        self.truncation = truncation

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used for serialization."""
        return "VoyageEmbedding"

    def _get_embedding(
        self, texts: List[str], input_type: Optional[str]
    ) -> List[List[float]]:
        """Embed ``texts`` synchronously with the configured model.

        Args:
            texts: Texts to embed in a single ``embed`` request.
            input_type: Passed through to Voyage ("query", "document", or
                None as used by ``get_general_text_embedding``).

        Returns:
            One embedding vector per input text.
        """
        return self._client.embed(
            texts,
            model=self.model_name,
            input_type=input_type,
            truncation=self.truncation,
        ).embeddings

    async def _aget_embedding(
        self, texts: List[str], input_type: Optional[str]
    ) -> List[List[float]]:
        """Async counterpart of ``_get_embedding`` using the async client."""
        r = await self._aclient.embed(
            texts,
            model=self.model_name,
            input_type=input_type,
            truncation=self.truncation,
        )
        return r.embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a single query (input_type="query")."""
        return self._get_embedding([query], input_type="query")[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding``."""
        r = await self._aget_embedding([query], input_type="query")
        return r[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single document text (input_type="document")."""
        return self._get_embedding([text], input_type="document")[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding``."""
        r = await self._aget_embedding([text], input_type="document")
        return r[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for several document texts in one request."""
        return self._get_embedding(texts, input_type="document")

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings``."""
        return await self._aget_embedding(texts, input_type="document")

    def get_general_text_embedding(
        self, text: str, input_type: Optional[str] = None
    ) -> List[float]:
        """Embed ``text`` with a caller-chosen ``input_type`` (None by default)."""
        return self._get_embedding([text], input_type=input_type)[0]

    async def aget_general_text_embedding(
        self, text: str, input_type: Optional[str] = None
    ) -> List[float]:
        """Async version of ``get_general_text_embedding``."""
        r = await self._aget_embedding([text], input_type=input_type)
        return r[0]
|