# Source code for langchain_community.embeddings.gigachat

from __future__ import annotations

import logging
from functools import cached_property
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)

MAX_BATCH_SIZE_CHARS = 1000000
MAX_BATCH_SIZE_PARTS = 90


class GigaChatEmbeddings(BaseModel, Embeddings):
    """GigaChat embedding models.

    Example:
        .. code-block:: python

            from langchain_community.embeddings.gigachat import GigaChatEmbeddings

            embeddings = GigaChatEmbeddings(
                credentials=..., scope=..., verify_ssl_certs=False
            )
    """

    base_url: Optional[str] = None
    """Base API URL"""
    auth_url: Optional[str] = None
    """Auth URL"""
    credentials: Optional[str] = None
    """Auth token"""
    scope: Optional[str] = None
    """Permission scope for the access token"""

    access_token: Optional[str] = None
    """Access token for GigaChat"""

    model: Optional[str] = None
    """Model name to use."""
    user: Optional[str] = None
    """Username for authentication"""
    password: Optional[str] = None
    """Password for authentication"""

    timeout: Optional[float] = 600
    """Request timeout. The default is sized for long-running requests."""
    verify_ssl_certs: Optional[bool] = None
    """Whether to verify certificates for all requests"""

    # Support for connection to GigaChat through SSL certificates
    ca_bundle_file: Optional[str] = None
    cert_file: Optional[str] = None
    key_file: Optional[str] = None
    key_file_password: Optional[str] = None

    @cached_property
    def _client(self) -> Any:
        """Return the GigaChat API client.

        The client is built on first access from this model's connection
        settings and then cached on the instance via ``cached_property``,
        so later calls reuse the same client object.
        """
        import gigachat

        return gigachat.GigaChat(
            base_url=self.base_url,
            auth_url=self.auth_url,
            credentials=self.credentials,
            scope=self.scope,
            access_token=self.access_token,
            model=self.model,
            user=self.user,
            password=self.password,
            timeout=self.timeout,
            verify_ssl_certs=self.verify_ssl_certs,
            ca_bundle_file=self.ca_bundle_file,
            cert_file=self.cert_file,
            key_file=self.key_file,
            key_file_password=self.key_file_password,
        )

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that the ``gigachat`` package is importable.

        Also logs a warning if ``values`` contains keys that are not declared
        fields of this model.

        Raises:
            ImportError: If the ``gigachat`` package is not installed.
        """
        try:
            import gigachat  # noqa: F401
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Could not import gigachat python package. "
                "Please install it with `pip install gigachat`."
            ) from exc
        fields = set(cls.__fields__.keys())
        diff = set(values.keys()) - fields
        if diff:
            logger.warning("Extra fields %s in GigaChat class", diff)
        return values

    def _embed_kwargs(self) -> Dict[str, Any]:
        """Return the extra keyword arguments for the client's embeddings calls.

        ``model`` is forwarded only when it is set; otherwise no ``model``
        argument is passed to the client.
        """
        return {"model": self.model} if self.model is not None else {}

    @staticmethod
    def _split_into_batches(texts: List[str]) -> List[List[str]]:
        """Split ``texts`` into consecutive batches that respect the request limits.

        A batch is closed *before* adding a text that would push it past
        ``MAX_BATCH_SIZE_PARTS`` items or ``MAX_BATCH_SIZE_CHARS`` characters.
        A single text longer than ``MAX_BATCH_SIZE_CHARS`` is not split; it is
        sent alone in its own batch. The order of texts is preserved, and an
        empty input yields no batches.
        """
        batches: List[List[str]] = []
        current: List[str] = []
        current_chars = 0
        for text in texts:
            would_overflow = (
                len(current) >= MAX_BATCH_SIZE_PARTS
                or current_chars + len(text) > MAX_BATCH_SIZE_CHARS
            )
            # Never emit an empty batch: an oversized first text goes in alone.
            if current and would_overflow:
                batches.append(current)
                current = []
                current_chars = 0
            current.append(text)
            current_chars += len(text)
        if current:
            batches.append(current)
        return batches

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a GigaChat embeddings model.

        Texts are sent in batches that stay within ``MAX_BATCH_SIZE_PARTS`` and
        ``MAX_BATCH_SIZE_CHARS``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        embed_kwargs = self._embed_kwargs()
        result: List[List[float]] = []
        for batch in self._split_into_batches(texts):
            response = self._client.embeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed documents using a GigaChat embeddings model.

        Texts are sent in batches that stay within ``MAX_BATCH_SIZE_PARTS`` and
        ``MAX_BATCH_SIZE_CHARS``.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        embed_kwargs = self._embed_kwargs()
        result: List[List[float]] = []
        for batch in self._split_into_batches(texts):
            response = await self._client.aembeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query using a GigaChat embeddings model.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        return self.embed_documents(texts=[text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single query using a GigaChat embeddings model.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        docs = await self.aembed_documents(texts=[text])
        return docs[0]