Source code for langchain_community.embeddings.baidu_qianfan_endpoint
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__)
class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
    """`Baidu Qianfan Embeddings` embedding models."""

    qianfan_ak: Optional[SecretStr] = None
    """Qianfan application API key.

    Read from the `QIANFAN_AK` environment variable when not passed explicitly.
    """

    qianfan_sk: Optional[SecretStr] = None
    """Qianfan application secret key.

    Read from the `QIANFAN_SK` environment variable when not passed explicitly.
    """

    chunk_size: int = 16
    """Number of texts sent per request when embedding multiple texts."""

    model: str = "Embedding-V1"
    """Model name.

    The list of models is available at
    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Currently supported preset models:

    - Embedding-V1 (default model)
    - bge-large-en
    - bge-large-zh

    Preset models are mapped to an endpoint.
    `model` is ignored if `endpoint` is set.
    """

    endpoint: str = ""
    """Endpoint of the Qianfan embedding; required when using a custom model."""

    client: Any
    """Qianfan client, created by `validate_environment`."""

    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments used to initialise the qianfan client, e.g.
    `query_per_second`, which is associated with the qianfan resource object
    and limits QPS."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra keyword arguments forwarded to the model on every `do` / `ado` call."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the Qianfan credentials and build the embedding client.

        `qianfan_ak` and `qianfan_sk` are taken from `values` or, failing that,
        from the `QIANFAN_AK` / `QIANFAN_SK` environment variables (defaulting
        to an empty secret). A `qianfan.Embedding` client is then created from
        `init_kwargs`, `model`, any non-empty credentials and a non-empty
        `endpoint`.

        Args:
            values: Field values collected so far for this model.

        Returns:
            The same dictionary with `qianfan_ak` and `qianfan_sk` stored as
            secret strings and `client` set to the new embedding client.

        Raises:
            ImportError: If the `qianfan` package is not installed.
        """
        for field_name, env_key in (
            ("qianfan_ak", "QIANFAN_AK"),
            ("qianfan_sk", "QIANFAN_SK"),
        ):
            values[field_name] = convert_to_secret_str(
                get_from_dict_or_env(values, field_name, env_key, default="")
            )

        # Guard only the import so that an ImportError raised while building
        # the client is not misreported as a missing `qianfan` package.
        try:
            import qianfan
        except ImportError as err:
            raise ImportError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            ) from err

        # A post root validator still runs when a field failed validation, and
        # that field is then absent from `values`; look fields up defensively
        # so the real validation error is not masked by a KeyError.
        params: Dict[str, Any] = dict(values.get("init_kwargs", {}))
        if "model" in values:
            # Assigned after init_kwargs so the `model` field takes precedence.
            params["model"] = values["model"]

        # Only forward credentials that are actually set.
        for param_name, field_name in (("ak", "qianfan_ak"), ("sk", "qianfan_sk")):
            secret = values[field_name].get_secret_value()
            if secret:
                params[param_name] = secret

        endpoint = values.get("endpoint")
        if endpoint:
            params["endpoint"] = endpoint

        values["client"] = qianfan.Embedding(**params)
        return values

    def _chunk_texts(self, texts: List[str]) -> List[List[str]]:
        """Split `texts` into consecutive batches of at most `chunk_size` items.

        Raises:
            ValueError: If `chunk_size` is smaller than 1.
        """
        if self.chunk_size < 1:
            # A negative step would silently yield no batches at all.
            raise ValueError(
                f"chunk_size must be a positive integer, got {self.chunk_size}"
            )
        return [
            texts[start : start + self.chunk_size]
            for start in range(0, len(texts), self.chunk_size)
        ]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of `text`, taken as the first result of
            `embed_documents([text])`.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts.

        The texts are sent to the client in batches of `chunk_size`, with
        `model_kwargs` forwarded on every call.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If `chunk_size` is smaller than 1.
        """
        embeddings: List[List[float]] = []
        for chunk in self._chunk_texts(texts):
            resp = self.client.do(texts=chunk, **self.model_kwargs)
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of `text`, taken as the first result of
            `aembed_documents([text])`.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of texts.

        The texts are sent to the client in batches of `chunk_size`, one batch
        after another, with `model_kwargs` forwarded on every call.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If `chunk_size` is smaller than 1.
        """
        embeddings: List[List[float]] = []
        for chunk in self._chunk_texts(texts):
            resp = await self.client.ado(texts=chunk, **self.model_kwargs)
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings