# Source code for langchain_community.embeddings.gigachat
from __future__ import annotations
import logging
from functools import cached_property
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Thresholds used when splitting texts into embeddings-request batches:
# total characters per batch and number of texts per batch.
# NOTE(review): presumably these mirror GigaChat API request limits — confirm.
MAX_BATCH_SIZE_CHARS: int = 1_000_000
MAX_BATCH_SIZE_PARTS: int = 90
class GigaChatEmbeddings(BaseModel, Embeddings):
    """GigaChat embedding models.

    Example:
        .. code-block:: python

            from langchain_community.embeddings.gigachat import GigaChatEmbeddings

            embeddings = GigaChatEmbeddings(
                credentials=..., scope=..., verify_ssl_certs=False
            )
    """

    base_url: Optional[str] = None
    """Base API URL"""
    auth_url: Optional[str] = None
    """Auth URL"""
    credentials: Optional[str] = None
    """Auth token"""
    scope: Optional[str] = None
    """Permission scope for the access token"""

    access_token: Optional[str] = None
    """Access token for GigaChat"""

    model: Optional[str] = None
    """Model name to use."""
    user: Optional[str] = None
    """Username for authentication"""
    password: Optional[str] = None
    """Password for authentication"""

    timeout: Optional[float] = 600
    """Request timeout. The default is sized for long-running requests."""
    verify_ssl_certs: Optional[bool] = None
    """Check certificates for all requests"""

    # Support for connection to GigaChat through SSL certificates
    ca_bundle_file: Optional[str] = None
    cert_file: Optional[str] = None
    key_file: Optional[str] = None
    key_file_password: Optional[str] = None

    @cached_property
    def _client(self) -> Any:
        """Return the GigaChat API client.

        The client is built lazily on first access from this instance's
        settings and then cached on the instance, so changing the fields
        above afterwards does not affect an already-created client.
        """
        import gigachat

        return gigachat.GigaChat(
            base_url=self.base_url,
            auth_url=self.auth_url,
            credentials=self.credentials,
            scope=self.scope,
            access_token=self.access_token,
            model=self.model,
            user=self.user,
            password=self.password,
            timeout=self.timeout,
            verify_ssl_certs=self.verify_ssl_certs,
            ca_bundle_file=self.ca_bundle_file,
            cert_file=self.cert_file,
            key_file=self.key_file,
            key_file_password=self.key_file_password,
        )

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that the ``gigachat`` package is installed; warn on unknown keys.

        Args:
            values: The model's field values as seen by the validator.

        Returns:
            ``values`` unchanged.

        Raises:
            ImportError: If the ``gigachat`` package cannot be imported.
        """
        try:
            import gigachat  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Could not import gigachat python package. "
                "Please install it with `pip install gigachat`."
            ) from exc
        fields = set(cls.__fields__.keys())
        diff = set(values.keys()) - fields
        if diff:
            # NOTE(review): whether unknown keys reach this post-validator
            # depends on the model's pydantic ``Extra`` config — verify this
            # warning can actually fire.
            logger.warning("Extra fields %s in GigaChat class", diff)
        return values

    @staticmethod
    def _split_into_batches(texts: List[str]) -> List[List[str]]:
        """Group ``texts`` into consecutive batches that respect the request limits.

        A batch holds at most ``MAX_BATCH_SIZE_PARTS`` texts and at most
        ``MAX_BATCH_SIZE_CHARS`` characters in total. The exception is a
        single text longer than the character limit, which is placed in a
        batch of its own. Input order is preserved, so concatenating the
        batches yields ``texts`` again.
        """
        batches: List[List[str]] = []
        current: List[str] = []
        current_chars = 0
        for text in texts:
            # Check *before* appending so a flushed batch never exceeds the
            # limits (the previous append-then-check flow could send
            # MAX_BATCH_SIZE_PARTS + 1 texts, or more than MAX_BATCH_SIZE_CHARS).
            would_overflow = (
                current_chars + len(text) > MAX_BATCH_SIZE_CHARS
                or len(current) >= MAX_BATCH_SIZE_PARTS
            )
            if current and would_overflow:
                batches.append(current)
                current = []
                current_chars = 0
            current.append(text)
            current_chars += len(text)
        if current:
            batches.append(current)
        return batches

    def _get_embed_kwargs(self) -> Dict[str, Any]:
        """Return extra keyword arguments for the embeddings call.

        Returns ``{"model": ...}`` only when a model is configured.
        """
        return {} if self.model is None else {"model": self.model}

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a GigaChat embeddings model.

        Texts are sent in batches (see ``_split_into_batches``); results are
        returned in input order.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. Empty input returns ``[]``
            without calling the API.
        """
        embed_kwargs = self._get_embed_kwargs()
        result: List[List[float]] = []
        for batch in self._split_into_batches(texts):
            response = self._client.embeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed documents using a GigaChat embeddings model.

        Texts are sent in batches (see ``_split_into_batches``); batches are
        awaited one after another and results are returned in input order.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. Empty input returns ``[]``
            without calling the API.
        """
        embed_kwargs = self._get_embed_kwargs()
        result: List[List[float]] = []
        for batch in self._split_into_batches(texts):
            response = await self._client.aembeddings(texts=batch, **embed_kwargs)
            result.extend(item.embedding for item in response.data)
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a GigaChat embeddings model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents(texts=[text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a query using a GigaChat embeddings model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        docs = await self.aembed_documents(texts=[text])
        return docs[0]