import logging
import re
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from langchain_community.llms.vertexai import _VertexAICommon
from langchain_community.utilities.vertexai import raise_vertex_import_error
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Upper bound on the *estimated* token count of one request batch; used by
# `_prepare_batches` to decide when a batch has to be closed.
_MAX_TOKENS_PER_BATCH = 20000
# Default for the largest number of texts per batch; the starting point of the
# dynamic batch-size detection unless `max_batch_size` is given in kwargs.
_MAX_BATCH_SIZE = 250
# Default for the smallest batch size the dynamic detection will shrink to,
# unless `min_batch_size` is given in kwargs.
_MIN_BATCH_SIZE = 5
@deprecated(
    since="0.0.12",
    removal="0.3.0",
    alternative_import="langchain_google_vertexai.VertexAIEmbeddings",
)
class VertexAIEmbeddings(_VertexAICommon, Embeddings):
    """Google Cloud VertexAI embedding models.

    Texts are grouped into batches (bounded by a batch size and by an estimated
    token budget) and the batches are sent to the model through a thread pool.
    When no explicit batch size is requested, the largest batch size the
    backend accepts is detected on first use and remembered in `instance`.
    """

    # Instance context: mutable per-object bookkeeping (batch sizes, lock,
    # executor, feature flags) filled in by `__init__`.
    instance: Dict[str, Any] = {}  #: :meta private:
    show_progress_bar: bool = False
    """Whether to show a tqdm progress bar. Must have `tqdm` installed."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the required Python package is available.

        Also resolves the placeholder default model name and creates the
        embedding client.

        Args:
            values: Field values collected by the validator.

        Returns:
            The same mapping with `model_name` resolved and `client` set.
        """
        cls._try_init_vertexai(values)
        if values["model_name"] == "textembedding-gecko-default":
            # The placeholder means the caller did not choose a model.
            logger.warning(
                "Model_name will become a required arg for VertexAIEmbeddings "
                "starting from Feb-01-2024. Currently the default is set to "
                "textembedding-gecko@001"
            )
            values["model_name"] = "textembedding-gecko@001"
        try:
            from vertexai.language_models import TextEmbeddingModel
        except ImportError:
            raise_vertex_import_error()
        values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
        return values

    def __init__(
        self,
        # the default value would be removed after Feb-01-2024
        model_name: str = "textembedding-gecko-default",
        project: Optional[str] = None,
        location: str = "us-central1",
        request_parallelism: int = 5,
        max_retries: int = 6,
        credentials: Optional[Any] = None,
        **kwargs: Any,
    ):
        """Initialize the embeddings wrapper and its batching state.

        Args:
            model_name: Name of the embedding model to load.
            project: Forwarded to the base class initializer.
            location: Forwarded to the base class initializer.
            request_parallelism: Number of worker threads used to send
                batches concurrently.
            max_retries: Maximum number of retries for retriable API errors.
            credentials: Forwarded to the base class initializer.
            **kwargs: Forwarded to the base class initializer. The optional
                keys `max_batch_size` and `min_batch_size` are also read here
                to bound the dynamic batch-size detection.
        """
        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            request_parallelism=request_parallelism,
            max_retries=max_retries,
            model_name=model_name,
            **kwargs,
        )
        # Bounds and current state of the dynamic batch-size detection.
        self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
        self.instance["batch_size"] = self.instance["max_batch_size"]
        self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
        self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
        # The lock serializes the one-time batch-size validation.
        self.instance["lock"] = threading.Lock()
        self.instance["batch_size_validated"] = False
        self.instance["task_executor"] = ThreadPoolExecutor(
            max_workers=request_parallelism
        )
        # Task types are not sent when the endpoint is textembedding-gecko@001.
        self.instance[
            "embeddings_task_type_supported"
        ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")

    @staticmethod
    def _split_by_punctuation(text: str) -> List[str]:
        """Split a string on punctuation and whitespace characters.

        The separators themselves are kept as individual segments; empty
        segments are dropped.

        Args:
            text: The string to split.

        Returns:
            The non-empty segments, in their original order.
        """
        split_by = string.punctuation + "\t\n "
        # The capturing group makes re.split return the separators as well.
        pattern = f"([{split_by}])"
        return [segment for segment in re.split(pattern, text) if segment]

    @staticmethod
    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Split texts into batches.

        A batch is closed when it holds `batch_size` texts or when adding the
        next text would push its estimated token count above
        `_MAX_TOKENS_PER_BATCH`. A text whose own estimate exceeds that budget
        is still emitted, alone in its own batch.

        Args:
            texts: The texts to batch, in order.
            batch_size: Maximum number of texts per batch. Must not be zero.

        Returns:
            The batches, preserving the order of `texts`.

        Raises:
            ValueError: If `batch_size` is zero while `texts` is not empty.
        """
        text_index = 0
        texts_len = len(texts)
        batch_token_len = 0
        batches: List[List[str]] = []
        current_batch: List[str] = []
        if texts_len == 0:
            return []
        if batch_size == 0:
            # An empty batch would already count as "full", so the loop below
            # would emit empty batches forever without consuming any text.
            raise ValueError("batch_size must not be zero.")
        while text_index < texts_len:
            current_text = texts[text_index]
            # Number of tokens per a text is conservatively estimated
            # as 2 times number of words, punctuation and whitespace characters.
            # Using `count_tokens` API will make batching too expensive.
            # Utilizing a tokenizer, would add a dependency that would not
            # necessarily be reused by the application using this class.
            current_text_token_cnt = (
                len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
            )
            end_of_batch = False
            if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
                # Current text is too big even for a single batch.
                # Such request will fail, but we still make a batch
                # so that the app can get the error from the API.
                if len(current_batch) > 0:
                    # Adding current batch if not empty.
                    batches.append(current_batch)
                current_batch = [current_text]
                text_index += 1
                end_of_batch = True
            elif (
                batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
                or len(current_batch) == batch_size
            ):
                # Close the batch without consuming the current text; it is
                # reconsidered against a fresh batch on the next iteration.
                end_of_batch = True
            else:
                if text_index == texts_len - 1:
                    # Last element - even though the batch may be not big,
                    # we still need to make it.
                    end_of_batch = True
                batch_token_len += current_text_token_cnt
                current_batch.append(current_text)
                text_index += 1
            if end_of_batch:
                batches.append(current_batch)
                current_batch = []
                batch_token_len = 0
        return batches

    def _get_embeddings_with_retry(
        self, texts: List[str], embeddings_type: Optional[str] = None
    ) -> List[List[float]]:
        """Request embeddings from the Vertex AI model, with retry logic.

        Args:
            texts: One batch of texts to embed in a single request.
            embeddings_type: Optional task type; only sent when the instance
                flag `embeddings_task_type_supported` is set.

        Returns:
            One embedding vector per input text.
        """
        from google.api_core.exceptions import (
            Aborted,
            DeadlineExceeded,
            ResourceExhausted,
            ServiceUnavailable,
        )

        # Only these error types are retried.
        errors = [
            ResourceExhausted,
            ServiceUnavailable,
            Aborted,
            DeadlineExceeded,
        ]
        retry_decorator = create_base_retry_decorator(
            error_types=errors, max_retries=self.max_retries
        )

        @retry_decorator
        def _completion_with_retry(texts_to_process: List[str]) -> Any:
            if embeddings_type and self.instance["embeddings_task_type_supported"]:
                from vertexai.language_models import TextEmbeddingInput

                requests = [
                    TextEmbeddingInput(text=t, task_type=embeddings_type)
                    for t in texts_to_process
                ]
            else:
                requests = texts_to_process
            embeddings = self.client.get_embeddings(requests)
            return [embs.values for embs in embeddings]

        return _completion_with_retry(texts)

    def _prepare_and_validate_batches(
        self, texts: List[str], embeddings_type: Optional[str] = None
    ) -> Tuple[List[List[float]], List[List[str]]]:
        """Prepare text batches, validating the batch size once.

        Batch size varies between GCP regions and individual project quotas,
        so the first sufficiently large batch is sent as a probe and halved
        after every `InvalidArgument` failure until it goes through.

        Args:
            texts: The texts to embed; must not be empty.
            embeddings_type: Optional task type forwarded to the probe request.

        Returns:
            A tuple of (embeddings of the first batch that went through, which
            is empty when no probe request was made; batches covering the
            texts that still have to be embedded).

        Raises:
            InvalidArgument: If the probe still fails at `min_batch_size`.
        """
        from google.api_core.exceptions import InvalidArgument

        batches = VertexAIEmbeddings._prepare_batches(
            texts, self.instance["batch_size"]
        )
        # If batch size is less or equal to one that went through before,
        # then keep batches as they are.
        if len(batches[0]) <= self.instance["min_good_batch_size"]:
            return [], batches
        with self.instance["lock"]:
            # If largest possible batch size was validated
            # while waiting for the lock, then check for rebuilding
            # our batches, and return.
            if self.instance["batch_size_validated"]:
                if len(batches[0]) <= self.instance["batch_size"]:
                    return [], batches
                return [], VertexAIEmbeddings._prepare_batches(
                    texts, self.instance["batch_size"]
                )
            # Figure out largest possible batch size by trying to push
            # batches and lowering their size in half after every failure.
            first_batch = batches[0]
            first_result: List[List[float]] = []
            had_failure = False
            while True:
                try:
                    first_result = self._get_embeddings_with_retry(
                        first_batch, embeddings_type
                    )
                    break
                except InvalidArgument:
                    had_failure = True
                    first_batch_len = len(first_batch)
                    if first_batch_len == self.instance["min_batch_size"]:
                        # Cannot shrink any further; surface the API error.
                        raise
                    first_batch_len = max(
                        self.instance["min_batch_size"], int(first_batch_len / 2)
                    )
                    first_batch = first_batch[:first_batch_len]
            first_batch_len = len(first_batch)
            self.instance["min_good_batch_size"] = max(
                self.instance["min_good_batch_size"], first_batch_len
            )
            # If had a failure and recovered
            # or went through with the max size, then it's a legit batch size.
            if had_failure or first_batch_len == self.instance["max_batch_size"]:
                self.instance["batch_size"] = first_batch_len
                self.instance["batch_size_validated"] = True
                if first_batch_len != self.instance["max_batch_size"]:
                    # Batch size was lowered: rebuild batches with the new
                    # size (texts that went through are excluded here).
                    batches = VertexAIEmbeddings._prepare_batches(
                        texts[first_batch_len:], self.instance["batch_size"]
                    )
                else:
                    # The full-size first batch went through unchanged; drop
                    # it so that its texts are not embedded a second time.
                    batches = batches[1:]
            else:
                # Still figuring out max batch size.
                batches = batches[1:]
        # Returning embeddings of the first text batch that went through,
        # and text batches for the rest of texts.
        return first_result, batches

    def embed(
        self,
        texts: List[str],
        batch_size: int = 0,
        embeddings_task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
            ]
        ] = None,
    ) -> List[List[float]]:
        """Embed a list of strings.

        Args:
            texts: The list of strings to embed.
            batch_size: The batch size of embeddings to send to the model.
                If zero, the largest batch size is detected dynamically
                at the first request, starting from 250, down to 5.
            embeddings_task_type: Optional embeddings task type,
                one of the following:
                    RETRIEVAL_QUERY - Text is a query
                        in a search/retrieval setting.
                    RETRIEVAL_DOCUMENT - Text is a document
                        in a search/retrieval setting.
                    SEMANTIC_SIMILARITY - Embeddings will be used
                        for Semantic Textual Similarity (STS).
                    CLASSIFICATION - Embeddings will be used for classification.
                    CLUSTERING - Embeddings will be used for clustering.

        Returns:
            List of embeddings, one for each text.
        """
        if not texts:
            return []
        embeddings: List[List[float]] = []
        first_batch_result: List[List[float]] = []
        if batch_size > 0:
            # Fixed batch size.
            batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
        else:
            # Dynamic batch size, starting from 250 at the first call.
            first_batch_result, batches = self._prepare_and_validate_batches(
                texts, embeddings_task_type
            )
        # First batch result may have some embeddings already.
        # In such case, batches have texts that were not processed yet.
        embeddings.extend(first_batch_result)
        if self.show_progress_bar:
            try:
                from tqdm import tqdm

                # The bar advances as batches are submitted to the executor.
                iter_ = tqdm(batches, desc="VertexAIEmbeddings")
            except ImportError:
                logger.warning(
                    "Unable to show progress bar because tqdm could not be imported. "
                    "Please install with `pip install tqdm`."
                )
                iter_ = batches
        else:
            iter_ = batches
        tasks = [
            self.instance["task_executor"].submit(
                self._get_embeddings_with_retry,
                texts=batch,
                embeddings_type=embeddings_task_type,
            )
            for batch in iter_
        ]
        if tasks:
            wait(tasks)
            # Collect in submission order so results line up with `texts`.
            for task in tasks:
                embeddings.extend(task.result())
        return embeddings

    def embed_documents(
        self, texts: List[str], batch_size: int = 0
    ) -> List[List[float]]:
        """Embed a list of documents.

        Args:
            texts: The list of texts to embed.
            batch_size: The batch size of embeddings to send to the model.
                If zero, the largest batch size is detected dynamically
                at the first request, starting from 250, down to 5.

        Returns:
            List of embeddings, one for each text.
        """
        return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
        return embeddings[0]