from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_community.utilities.vertexai import (
create_retry_decorator,
get_client_info,
init_vertexai,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from google.cloud.aiplatform.models import Prediction
from google.protobuf.struct_pb2 import Value
from vertexai.language_models._language_models import (
TextGenerationResponse,
_LanguageModel,
)
from vertexai.preview.generative_models import Image
# Backwards-compatibility placeholders: these module-level names are bound to
# None so that `langchain` can keep importing them from this module.
# They can be removed once `langchain` stops importing them.
# NOTE(review): `completion_with_retry` is rebound to a real function further
# down in this module; no module-level redefinition of the other two names is
# visible here -- confirm before removing.
_response_to_generation = None
completion_with_retry = None
stream_completion_with_retry = None
def is_codey_model(model_name: str) -> bool:
    """Return True if the model name is a Codey model.

    A name counts as a Codey model when it contains the substring ``"code"``.

    Args:
        model_name: Name of the model to inspect. ``None`` is tolerated and
            reported as "not a Codey model", the same way ``is_gemini_model``
            treats a missing name, instead of raising ``TypeError``.

    Returns:
        True if ``"code"`` occurs in ``model_name``, False otherwise.
    """
    return model_name is not None and "code" in model_name
def is_gemini_model(model_name: str) -> bool:
    """Return True if the model name is a Gemini model.

    Args:
        model_name: Name of the model to inspect; ``None`` is accepted.

    Returns:
        True if ``model_name`` is not ``None`` and contains ``"gemini"``,
        False otherwise.
    """
    if model_name is None:
        return False
    return "gemini" in model_name
def completion_with_retry(  # type: ignore[no-redef]
    llm: VertexAI,
    prompt: List[Union[str, "Image"]],
    stream: bool = False,
    is_gemini: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call.

    Args:
        llm: Model wrapper whose ``client`` performs the request; it is also
            handed to ``create_retry_decorator``.
        prompt: List of prompt parts. Gemini models receive the whole list;
            other models receive only the first element.
        stream: Whether to request a streaming response.
        is_gemini: Whether to use the Gemini ``generate_content`` call.
        run_manager: Optional callback manager forwarded to the retry
            decorator.
        **kwargs: Generation parameters. For Gemini they are passed as
            ``generation_config``; otherwise they are expanded as keyword
            arguments of the predict call.

    Returns:
        Whatever the underlying client call returns.
    """
    with_retry = create_retry_decorator(llm, run_manager=run_manager)

    def _completion_with_retry(
        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if is_gemini:
            return llm.client.generate_content(
                prompt, stream=stream, generation_config=kwargs
            )
        # Non-Gemini models take a single text prompt; pick the streaming or
        # the plain predict method and call it with the first prompt part.
        predict = llm.client.predict_streaming if stream else llm.client.predict
        return predict(prompt[0], **kwargs)

    guarded_call = with_retry(_completion_with_retry)
    return guarded_call(prompt, is_gemini, **kwargs)
async def acompletion_with_retry(
    llm: VertexAI,
    prompt: str,
    is_gemini: bool = False,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call.

    Args:
        llm: Model wrapper whose ``client`` performs the request; it is also
            handed to ``create_retry_decorator``.
        prompt: Prompt passed to the client.
        is_gemini: Whether to use the Gemini ``generate_content_async`` call.
        run_manager: Optional async callback manager forwarded to the retry
            decorator.
        **kwargs: Generation parameters. For Gemini they are passed as
            ``generation_config``; otherwise they are expanded as keyword
            arguments of ``predict_async``.

    Returns:
        Whatever the awaited client call returns.
    """
    with_retry = create_retry_decorator(llm, run_manager=run_manager)

    async def _acompletion_with_retry(
        prompt: str, is_gemini: bool = False, **kwargs: Any
    ) -> Any:
        if not is_gemini:
            return await llm.client.predict_async(prompt, **kwargs)
        return await llm.client.generate_content_async(
            prompt, generation_config=kwargs
        )

    guarded_call = with_retry(_acompletion_with_retry)
    return await guarded_call(prompt, is_gemini, **kwargs)
class _VertexAIBase(BaseModel):
    """Configuration fields shared by the Vertex AI wrappers in this module."""

    project: Optional[str] = None
    "The default GCP project to use when making Vertex API calls."
    location: str = "us-central1"
    "The default location to use when making API calls."
    request_parallelism: int = 5
    "The amount of parallelism allowed for requests issued to VertexAI models. "
    "Default is 5."
    max_retries: int = 6
    """The maximum number of retries to make when generating."""
    # Class-level slot for the lazily created executor. It must start out as a
    # real ``None``: a ``ClassVar`` is not turned into a pydantic field, so a
    # ``Field(...)`` default would stay on the class as a ``FieldInfo`` object
    # and the ``is None`` check in ``_get_task_executor`` would never succeed.
    task_executor: ClassVar[Optional[Executor]] = None
    stop: Optional[List[str]] = None
    "Optional list of stop words to use when generating."
    model_name: Optional[str] = None
    "Underlying model name."

    @classmethod
    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
        """Return the class-level executor, creating it on first use.

        Args:
            request_parallelism: ``max_workers`` for the ``ThreadPoolExecutor``
                created on the first call. It has no effect once an executor
                is already stored on the class.

        Returns:
            The executor stored in ``cls.task_executor``.
        """
        if cls.task_executor is None:
            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
        return cls.task_executor
class _VertexAICommon(_VertexAIBase):
    """Generation settings and parameter helpers for Vertex AI LLM wrappers."""

    client: "_LanguageModel" = None  #: :meta private:
    client_preview: "_LanguageModel" = None  #: :meta private:
    model_name: str
    "Underlying model name."
    temperature: float = 0.0
    "Sampling temperature, it controls the degree of randomness in token selection."
    max_output_tokens: int = 128
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: float = 0.95
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int = 40
    "How the model selects tokens for output, the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    credentials: Any = Field(default=None, exclude=True)
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the environment."
    n: int = 1
    """How many completions to generate for each prompt."""
    streaming: bool = False
    """Whether to stream the results or not."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM type."""
        return "vertexai"

    @property
    def is_codey_model(self) -> bool:
        """Whether ``model_name`` is classified as a Codey model."""
        return is_codey_model(self.model_name)

    @property
    def _is_gemini_model(self) -> bool:
        """Whether ``model_name`` is classified as a Gemini model."""
        return is_gemini_model(self.model_name)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters: model name plus default params."""
        identifying: Dict[str, Any] = {"model_name": self.model_name}
        identifying.update(self._default_params)
        return identifying

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default generation parameters derived from the instance fields.

        ``top_k`` and ``top_p`` are only included for non-Codey models.
        """
        defaults = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        if self.is_codey_model:
            return defaults
        defaults["top_k"] = self.top_k
        defaults["top_p"] = self.top_p
        return defaults

    @classmethod
    def _try_init_vertexai(cls, values: Dict) -> None:
        """Call ``init_vertexai`` with the project/location/credentials values.

        Only the keys ``project``, ``location`` and ``credentials`` are
        forwarded; every other entry of ``values`` is ignored.
        """
        forwarded = ("project", "location", "credentials")
        init_kwargs = {
            key: value for key, value in values.items() if key in forwarded
        }
        init_vertexai(**init_kwargs)

    def _prepare_params(
        self,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Build the parameter dict for a generation call.

        Args:
            stop: Stop sequences for this call; falls back to ``self.stop``
                when falsy.
            stream: Whether the call is a streaming one.
            **kwargs: Per-call overrides. The key ``n`` is renamed to
                ``candidate_count``; overrides take precedence over defaults.

        Returns:
            The merged parameters. ``candidate_count`` is removed when either
            ``stream`` or ``self.streaming`` is true.
        """
        stop_sequences = stop or self.stop
        overrides = {
            ("candidate_count" if key == "n" else key): value
            for key, value in kwargs.items()
        }
        merged = dict(self._default_params)
        merged["stop_sequences"] = stop_sequences
        merged.update(overrides)
        if stream or self.streaming:
            del merged["candidate_count"]
        return merged
@deprecated(
    since="0.0.12",
    removal="0.3.0",
    alternative_import="langchain_google_vertexai.VertexAI",
)
class VertexAI(_VertexAICommon, BaseLLM):
    """Google Vertex AI large language models."""

    model_name: str = "text-bison"
    "The name of the Vertex AI large language model."
    tuned_model_name: Optional[str] = None
    "The name of a tuned model. If provided, the clients are loaded with "
    "`get_tuned_model(tuned_model_name)` instead of from `model_name`."

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this class is marked as LangChain-serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "vertexai"]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment.

        Initializes Vertex AI from ``values``, imports the model classes and
        stores the constructed models under ``values["client"]`` and
        ``values["client_preview"]``.

        Args:
            values: Field values of the model being validated.

        Returns:
            ``values`` with ``client`` and ``client_preview`` populated.

        Raises:
            ValueError: If streaming is enabled together with ``n > 1``.
        """
        tuned_model_name = values.get("tuned_model_name")
        model_name = values["model_name"]
        is_gemini = is_gemini_model(values["model_name"])
        cls._try_init_vertexai(values)
        # NOTE: the try block also covers client construction, so an
        # ImportError raised there is reported through the same helper.
        try:
            from vertexai.language_models import (
                CodeGenerationModel,
                TextGenerationModel,
            )
            from vertexai.preview.language_models import (
                CodeGenerationModel as PreviewCodeGenerationModel,
            )
            from vertexai.preview.language_models import (
                TextGenerationModel as PreviewTextGenerationModel,
            )

            if is_gemini:
                from vertexai.preview.generative_models import (
                    GenerativeModel,
                )

            # The Codey check comes first, so a name matching both checks is
            # handled as a code generation model.
            if is_codey_model(model_name):
                model_cls = CodeGenerationModel
                preview_model_cls = PreviewCodeGenerationModel
            elif is_gemini:
                model_cls = GenerativeModel
                preview_model_cls = GenerativeModel
            else:
                model_cls = TextGenerationModel
                preview_model_cls = PreviewTextGenerationModel

            if tuned_model_name:
                values["client"] = model_cls.get_tuned_model(tuned_model_name)
                values["client_preview"] = preview_model_cls.get_tuned_model(
                    tuned_model_name
                )
            else:
                if is_gemini:
                    values["client"] = model_cls(model_name=model_name)
                    values["client_preview"] = preview_model_cls(model_name=model_name)
                else:
                    values["client"] = model_cls.from_pretrained(model_name)
                    values["client_preview"] = preview_model_cls.from_pretrained(
                        model_name
                    )

        except ImportError:
            raise_vertex_import_error()

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Only one candidate can be generated with streaming!")
        return values

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        try:
            result = self.client_preview.count_tokens([text])
        except AttributeError:
            # An AttributeError from the preview client is reported through
            # the import-error helper.
            raise_vertex_import_error()

        return result.total_tokens

    def _response_to_generation(
        self, response: TextGenerationResponse
    ) -> GenerationChunk:
        """Convert a stream response to a generation chunk.

        Args:
            response: Response (or candidate) exposing a ``text`` attribute.

        Returns:
            A chunk with the response text. ``generation_info`` carries
            ``is_blocked`` and ``safety_attributes`` when both can be read
            from the response, and is ``None`` otherwise.
        """
        try:
            generation_info = {
                "is_blocked": response.is_blocked,
                "safety_attributes": response.safety_attributes,
            }
        except Exception:
            # Best effort: any failure reading these attributes just drops
            # the metadata.
            generation_info = None
        return GenerationChunk(text=response.text, generation_info=generation_info)

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Generate completions for each prompt.

        Args:
            prompts: Prompts to complete.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for the run.
            stream: Overrides ``self.streaming`` when not ``None``.
            **kwargs: Per-call generation parameter overrides.

        Returns:
            An ``LLMResult`` with one list of generations per prompt. In
            streaming mode each list holds a single generation accumulated
            from the streamed chunks.
        """
        should_stream = stream if stream is not None else self.streaming
        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
        generations: List[List[Generation]] = []
        for prompt in prompts:
            if should_stream:
                # Accumulate all streamed chunks into one generation.
                generation = GenerationChunk(text="")
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    generation += chunk
                generations.append([generation])
            else:
                res = completion_with_retry(  # type: ignore[misc]
                    self,
                    [prompt],
                    stream=should_stream,
                    is_gemini=self._is_gemini_model,
                    run_manager=run_manager,
                    **params,
                )
                generations.append(
                    [self._response_to_generation(r) for r in res.candidates]
                )
        return LLMResult(generations=generations)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously generate completions for each prompt.

        Prompts are awaited one after another.

        Args:
            prompts: Prompts to complete.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager for the run.
            **kwargs: Per-call generation parameter overrides.

        Returns:
            An ``LLMResult`` with one list of generations per prompt.
        """
        params = self._prepare_params(stop=stop, **kwargs)
        generations = []
        for prompt in prompts:
            res = await acompletion_with_retry(
                self,
                prompt,
                is_gemini=self._is_gemini_model,
                run_manager=run_manager,
                **params,
            )
            generations.append(
                [self._response_to_generation(r) for r in res.candidates]
            )
        return LLMResult(generations=generations)  # type: ignore[arg-type]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks for a single prompt.

        Args:
            prompt: Prompt to complete.
            stop: Optional stop sequences.
            run_manager: Optional callback manager; notified with
                ``on_llm_new_token`` for every chunk before it is yielded.
            **kwargs: Per-call generation parameter overrides.

        Yields:
            One ``GenerationChunk`` per streamed response.
        """
        params = self._prepare_params(stop=stop, stream=True, **kwargs)
        for stream_resp in completion_with_retry(  # type: ignore[misc]
            self,
            [prompt],
            stream=True,
            is_gemini=self._is_gemini_model,
            run_manager=run_manager,
            **params,
        ):
            chunk = self._response_to_generation(stream_resp)
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )
            yield chunk
@deprecated(
    since="0.0.12",
    removal="0.3.0",
    alternative_import="langchain_google_vertexai.VertexAIModelGarden",
)
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
    """Vertex AI Model Garden large language models."""

    client: "PredictionServiceClient" = None  #: :meta private:
    async_client: "PredictionServiceAsyncClient" = None  #: :meta private:
    endpoint_id: str
    "A name of an endpoint where the model has been deployed."
    allowed_model_args: Optional[List[str]] = None
    "Allowed optional args to be passed to the model."
    prompt_arg: str = "prompt"
    result_arg: Optional[str] = "generated_text"
    "Set result_arg to None if output of the model is expected to be a string."
    "Otherwise, if it's a dict, provide an argument that contains the result."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment.

        Builds the sync and async prediction clients for the configured
        location and stores them in ``values``.

        Args:
            values: Field values of the model being validated.

        Returns:
            ``values`` with ``client`` and ``async_client`` populated.

        Raises:
            ValueError: If no GCP project is configured.
        """
        try:
            from google.api_core.client_options import ClientOptions
            from google.cloud.aiplatform.gapic import (
                PredictionServiceAsyncClient,
                PredictionServiceClient,
            )
        except ImportError:
            raise_vertex_import_error()

        if not values["project"]:
            raise ValueError(
                "A GCP project should be provided to run inference on Model Garden!"
            )

        client_options = ClientOptions(
            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
        )
        client_info = get_client_info(module="vertex-ai-model-garden")
        values["client"] = PredictionServiceClient(
            client_options=client_options, client_info=client_info
        )
        values["async_client"] = PredictionServiceAsyncClient(
            client_options=client_options, client_info=client_info
        )
        return values

    @property
    def endpoint_path(self) -> str:
        """Endpoint path built by the client from project, location and id."""
        return self.client.endpoint_path(
            project=self.project,
            location=self.location,
            endpoint=self.endpoint_id,
        )

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM type."""
        return "vertexai_model_garden"

    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
        """Convert prompts into protobuf ``Value`` instances for prediction.

        Args:
            prompts: Prompts to send; each becomes one instance.
            **kwargs: Extra model arguments. Only keys listed in
                ``allowed_model_args`` are kept; all are dropped when
                ``allowed_model_args`` is empty or ``None``.

        Returns:
            One ``Value`` per prompt, holding the allowed arguments plus the
            prompt stored under ``prompt_arg``.

        Raises:
            ImportError: If the ``protobuf`` package is not installed.
        """
        try:
            from google.protobuf import json_format
            from google.protobuf.struct_pb2 import Value
        except ImportError:
            raise ImportError(
                "protobuf package not found, please install it with"
                " `pip install protobuf`"
            )
        # The allowed-argument filter does not depend on the prompt, so it is
        # computed once and copied into each instance.
        if self.allowed_model_args:
            shared_args = {
                k: v for k, v in kwargs.items() if k in self.allowed_model_args
            }
        else:
            shared_args = {}
        instances = [{**shared_args, self.prompt_arg: prompt} for prompt in prompts]
        return [
            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
        ]

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts and input.

        Args:
            prompts: Prompts to send to the endpoint.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra model arguments, filtered by ``_prepare_request``.

        Returns:
            The parsed ``LLMResult``.
        """
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        return self._parse_response(response)

    def _parse_response(self, predictions: "Prediction") -> LLMResult:
        """Turn a prediction response into an ``LLMResult``.

        Args:
            predictions: Response object exposing a ``predictions`` iterable.

        Returns:
            An ``LLMResult`` with one list of generations per prediction
            entry. An entry that is a plain string yields exactly one
            generation; any other entry is iterated and each item parsed.
        """
        generations: List[List[Generation]] = []
        for result in predictions.predictions:
            # A bare string is a single prediction; iterating it directly
            # would produce one Generation per character.
            candidates = [result] if isinstance(result, str) else result
            generations.append(
                [
                    Generation(text=self._parse_prediction(prediction))
                    for prediction in candidates
                ]
            )
        return LLMResult(generations=generations)

    def _parse_prediction(self, prediction: Any) -> str:
        """Extract the generated text from a single prediction.

        Args:
            prediction: A string, or a mapping holding the text under
                ``result_arg``.

        Returns:
            ``prediction`` itself when it is a string or ``result_arg`` is
            unset, otherwise ``prediction[result_arg]``.

        Raises:
            ValueError: If ``result_arg`` is set but missing in ``prediction``.
        """
        if isinstance(prediction, str):
            return prediction

        if self.result_arg:
            try:
                return prediction[self.result_arg]
            except KeyError as err:
                # Strings returned above, so a KeyError here always means the
                # key is missing from a non-string prediction.
                raise ValueError(
                    f"{self.result_arg} key not found in prediction!"
                ) from err

        return prediction

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously run the LLM on the given prompts and input.

        Args:
            prompts: Prompts to send to the endpoint.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra model arguments, filtered by ``_prepare_request``.

        Returns:
            The parsed ``LLMResult``.
        """
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)