# Source code for langchain_community.llms.ai21

from typing import Any, Dict, List, Optional, cast

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


class AI21PenaltyData(BaseModel):
    """Parameters for AI21 penalty data.

    ``AI21`` serializes instances of this model with ``.dict()`` and sends them
    as the ``presencePenalty``, ``countPenalty`` and ``frequencyPenalty``
    request fields, so the field names below are the exact JSON keys sent.
    """

    # Penalty strength; NOTE(review): 0 presumably means "no penalty" —
    # confirm against the AI21 API documentation.
    scale: int = 0
    # Flags selecting which token categories the penalty applies to
    # (meaning inferred from the names; all enabled by default).
    applyToWhitespaces: bool = True
    applyToPunctuations: bool = True
    applyToNumbers: bool = True
    applyToStopwords: bool = True
    applyToEmojis: bool = True
class AI21(LLM):
    """AI21 large language models.

    To use, you should have the environment variable ``AI21_API_KEY`` set with
    your API key, or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import AI21
            ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
    """

    model: str = "j2-jumbo-instruct"
    """Model name to use."""

    temperature: float = 0.7
    """What sampling temperature to use."""

    maxTokens: int = 256
    """The maximum number of tokens to generate in the completion."""

    minTokens: int = 0
    """The minimum number of tokens to generate in the completion."""

    topP: float = 1.0
    """Total probability mass of tokens to consider at each step."""

    presencePenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens."""

    countPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to count."""

    frequencyPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to frequency."""

    numResults: int = 1
    """How many completions to generate for each prompt."""

    logitBias: Optional[Dict[str, float]] = None
    """Adjust the probability of specific tokens being generated."""

    ai21_api_key: Optional[SecretStr] = None

    stop: Optional[List[str]] = None

    base_url: Optional[str] = None
    """Base URL to use; if None, it is chosen based on the model name."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the constructor args or ``AI21_API_KEY``.

        The resolved value is wrapped with ``convert_to_secret_str`` and
        stored back into ``values["ai21_api_key"]``.
        """
        ai21_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
        )
        values["ai21_api_key"] = ai21_api_key
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the AI21 API."""
        return {
            "temperature": self.temperature,
            "maxTokens": self.maxTokens,
            "minTokens": self.minTokens,
            "topP": self.topP,
            "presencePenalty": self.presencePenalty.dict(),
            "countPenalty": self.countPenalty.dict(),
            "frequencyPenalty": self.frequencyPenalty.dict(),
            "numResults": self.numResults,
            "logitBias": self.logitBias,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters: the model name plus defaults."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ai21"

    def _resolve_base_url(self) -> str:
        """Return ``base_url`` if set, otherwise pick one from the model name."""
        if self.base_url is not None:
            return self.base_url
        if self.model in ("j1-grande-instruct",):
            return "https://api.ai21.com/studio/v1/experimental"
        return "https://api.ai21.com/studio/v1"

    @staticmethod
    def _error_detail(response: requests.Response) -> Any:
        """Best-effort extraction of the ``error`` field from a failed response.

        Returns the ``error`` value of a JSON object body (None if absent),
        the decoded body if it is JSON but not an object, or the raw text if
        the body is not JSON at all.
        """
        try:
            body = response.json()
        except ValueError:
            # requests raises a ValueError subclass on undecodable bodies;
            # don't let that mask the HTTP failure being reported.
            return response.text
        if isinstance(body, dict):
            return body.get("error")
        return body

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to AI21's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run (not used here).
            **kwargs: Extra request parameters; they override the defaults
                from ``_default_params``.

        Returns:
            The string generated by the model (text of the first completion).

        Raises:
            ValueError: If ``stop`` is set both on the instance and in the
                call, or if the API responds with a status code other than 200.

        Example:
            .. code-block:: python

                response = ai21("Tell me a joke.")
        """
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []
        base_url = self._resolve_base_url()
        params = {**self._default_params, **kwargs}
        # validate_environment populates this field; the cast only narrows the
        # Optional type and, being a local, leaves the model unmutated.
        api_key = cast(SecretStr, self.ai21_api_key)
        response = requests.post(
            url=f"{base_url}/{self.model}/complete",
            headers={"Authorization": f"Bearer {api_key.get_secret_value()}"},
            json={"prompt": prompt, "stopSequences": stop, **params},
        )
        if response.status_code != 200:
            optional_detail = self._error_detail(response)
            raise ValueError(
                f"AI21 /complete call failed with status code {response.status_code}."
                f" Details: {optional_detail}"
            )
        response_json = response.json()
        return response_json["completions"][0]["data"]["text"]