Source code for langchain_community.llms.google_palm

from __future__ import annotations

from typing import Any, Dict, Iterator, List, Optional

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_core.language_models.llms import BaseLLM
from langchain_community.utilities.vertexai import create_retry_decorator


def completion_with_retry(
    llm: GooglePalm,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        generation_config = kwargs.get("generation_config", {})
        if is_gemini:
            return llm.client.generate_content(
                contents=prompt, stream=stream, generation_config=generation_config
            )
        return llm.client.generate_text(prompt=prompt, **kwargs)

    return _completion_with_retry(
        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
    )
def _is_gemini_model(model_name: str) -> bool:
    return "gemini" in model_name


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines after the first. This function strips that space.
    """
    has_leading_space = all(
        not line or line[0] == " " for line in text.split("\n")[1:]
    )
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text
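# For example, _strip_erroneous_leading_spaces("a\n b\n c") evaluates to
# "a\nb\nc", whereas "a\nb\n c" is returned unchanged because not every line
# after the first starts with a space.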
@deprecated("0.0.12", alternative_import="langchain_google_genai.GoogleGenerativeAI")
class GooglePalm(BaseLLM, BaseModel):
    """DEPRECATED: Use ``langchain_google_genai.GoogleGenerativeAI`` instead.

    Google PaLM models.
    """

    client: Any  #: :meta private:
    google_api_key: Optional[SecretStr]
    model_name: str = "models/text-bison-001"
    """Model name to use."""
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_retries: int = 6
    """The maximum number of retries to make when generating."""

    @property
    def is_gemini(self) -> bool:
        """Return whether the model belongs to the Gemini family."""
        return _is_gemini_model(self.model_name)

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}
    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True
    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "google_palm"]
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and Python package exist."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        model_name = values["model_name"]
        try:
            import google.generativeai as genai

            if isinstance(google_api_key, SecretStr):
                google_api_key = google_api_key.get_secret_value()

            genai.configure(api_key=google_api_key)

            if _is_gemini_model(model_name):
                values["client"] = genai.GenerativeModel(model_name=model_name)
            else:
                values["client"] = genai
        except ImportError:
            raise ImportError(
                "Could not import google-generativeai python package. "
                "Please install it with `pip install google-generativeai`."
            )

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if (
            values["max_output_tokens"] is not None
            and values["max_output_tokens"] <= 0
        ):
            raise ValueError("max_output_tokens must be greater than zero")

        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        for prompt in prompts:
            if self.is_gemini:
                res = completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                )
                candidates = [
                    "".join([p.text for p in c.content.parts]) for c in res.candidates
                ]
                generations.append([Generation(text=c) for c in candidates])
            else:
                res = completion_with_retry(
                    self,
                    model=self.model_name,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                prompt_generations = []
                for candidate in res.candidates:
                    raw_text = candidate["output"]
                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
                    prompt_generations.append(Generation(text=stripped_text))
                generations.append(prompt_generations)

        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generation_config = kwargs.get("generation_config", {})
        if stop:
            generation_config["stop_sequences"] = stop
        for stream_resp in completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"
    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self.is_gemini:
            raise ValueError("Counting tokens is not yet supported!")
        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
        return result["token_count"]
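# Example usage (a minimal sketch; assumes the `google-generativeai` package is
# installed and a valid GOOGLE_API_KEY is available in the environment):
#
#     from langchain_community.llms.google_palm import GooglePalm
#
#     llm = GooglePalm(temperature=0.2)
#     print(llm.invoke("Write a one-line greeting."))
#
#     # _stream always goes through the Gemini code path, so streaming assumes a
#     # Gemini-family model name such as "models/gemini-pro":
#     gemini = GooglePalm(model_name="models/gemini-pro")
#     for chunk in gemini.stream("Write a haiku about autumn."):
#         print(chunk, end="", flush=True)
#
# Note that this class is deprecated; `langchain_google_genai.GoogleGenerativeAI`
# is the suggested replacement.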