class MistralAIEmbedding(BaseEmbedding):
    """Embedding model backed by the MistralAI embeddings API.

    Args:
        model_name (str): Model used to produce embeddings.
            Defaults to "mistral-embed".
        api_key (Optional[str]): API key used to access the model. It can be
            passed as an argument or supplied through ``MISTRAL_API_KEY``.
            Defaults to None.
        embed_batch_size (int): Forwarded to the base-class constructor.
            Defaults to ``DEFAULT_EMBED_BATCH_SIZE``.
        callback_manager (Optional[CallbackManager]): Forwarded to the
            base-class constructor. Defaults to None.
        **kwargs (Any): Extra keyword arguments forwarded to the base-class
            constructor.

    Raises:
        ValueError: If no API key could be resolved.
    """

    # Private (non-field) attributes holding the sync and async API clients.
    _mistralai_client: Any = PrivateAttr()
    _mistralai_async_client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "mistral-embed",
        api_key: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ):
        # Validate the key up front so a misconfiguration fails before any
        # model or client object is built.
        api_key = get_from_param_or_env("api_key", api_key, "MISTRAL_API_KEY", "")
        if not api_key:
            raise ValueError(
                "You must provide an API key to use mistralai. "
                "You can either pass it in as an argument or set it `MISTRAL_API_KEY`."
            )

        # Initialise the base model first: Pydantic sets up the storage for
        # private attributes during ``__init__``, so assigning them earlier
        # can raise AttributeError (notably on Pydantic v2).
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            **kwargs,
        )

        self._mistralai_client = MistralClient(api_key=api_key)
        self._mistralai_async_client = MistralAsyncClient(api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this class."""
        return "MistralAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get the embedding for a single query string."""
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=[query]
        )
        return response.data[0].embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous version of ``_get_query_embedding``."""
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=[query]
        )
        return response.data[0].embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text string."""
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=[text]
        )
        return response.data[0].embedding

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronous version of ``_get_text_embedding``."""
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=[text]
        )
        return response.data[0].embedding

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a batch of texts, one per entry of the response data."""
        if not texts:
            # Nothing to embed, so skip the API round trip.
            return []
        response = self._mistralai_client.embeddings(
            model=self.model_name, input=texts
        )
        return [item.embedding for item in response.data]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous version of ``_get_text_embeddings``."""
        if not texts:
            # Nothing to embed, so skip the API round trip.
            return []
        response = await self._mistralai_async_client.embeddings(
            model=self.model_name, input=texts
        )
        return [item.embedding for item in response.data]