Source code for langchain_community.llms.writer

from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens


[docs]class Writer(LLM): """Writer大型语言模型。 要使用,您应该设置环境变量``WRITER_API_KEY``和``WRITER_ORG_ID``,分别使用您的API密钥和组织ID。 示例: .. code-block:: python from langchain_community.llms import Writer writer = Writer(model_id="palmyra-base") """ writer_org_id: Optional[str] = None """写入者组织ID。""" model_id: str = "palmyra-instruct" """要使用的模型名称。""" min_tokens: Optional[int] = None """生成所需的最小令牌数量。""" max_tokens: Optional[int] = None """生成的令牌的最大数量。""" temperature: Optional[float] = None """使用哪种采样温度。""" top_p: Optional[float] = None """每一步需要考虑的标记的总概率质量。""" stop: Optional[List[str]] = None """完成生成时序列将停止。""" presence_penalty: Optional[float] = None """无论频率如何,都会惩罚重复的标记。""" repetition_penalty: Optional[float] = None """根据频率惩罚重复的标记。""" best_of: Optional[int] = None """在服务器端生成这么多完成并返回“最佳”。""" logprobs: bool = False """是否返回对数概率。""" n: Optional[int] = None """生成多少个完成。""" writer_api_key: Optional[str] = None """写入API密钥。""" base_url: Optional[str] = None """基础URL的使用,如果为None,则根据模型名称决定。""" class Config: """此pydantic对象的配置。""" extra = Extra.forbid @root_validator() def validate_environment(cls, values: Dict) -> Dict: """验证环境中是否存在API密钥和组织ID。""" writer_api_key = get_from_dict_or_env( values, "writer_api_key", "WRITER_API_KEY" ) values["writer_api_key"] = writer_api_key writer_org_id = get_from_dict_or_env(values, "writer_org_id", "WRITER_ORG_ID") values["writer_org_id"] = writer_org_id return values @property def _default_params(self) -> Mapping[str, Any]: """获取调用Writer API的默认参数。""" return { "minTokens": self.min_tokens, "maxTokens": self.max_tokens, "temperature": self.temperature, "topP": self.top_p, "stop": self.stop, "presencePenalty": self.presence_penalty, "repetitionPenalty": self.repetition_penalty, "bestOf": self.best_of, "logprobs": self.logprobs, "n": self.n, } @property def _identifying_params(self) -> Mapping[str, Any]: """获取识别参数。""" return { **{"model_id": self.model_id, "writer_org_id": self.writer_org_id}, **self._default_params, } @property def _llm_type(self) -> str: """llm的返回类型。""" return "writer" def _call( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """调用Writer的完成端点。 参数: prompt: 传递给模型的提示。 stop: 生成时可选的停止词列表。 返回: 模型生成的字符串。 示例: .. code-block:: python response = Writer("Tell me a joke.") """ if self.base_url is not None: base_url = self.base_url else: base_url = ( "https://enterprise-api.writer.com/llm" f"/organization/{self.writer_org_id}" f"/model/{self.model_id}/completions" ) params = {**self._default_params, **kwargs} response = requests.post( url=base_url, headers={ "Authorization": f"{self.writer_api_key}", "Content-Type": "application/json", "Accept": "application/json", }, json={"prompt": prompt, **params}, ) text = response.text if stop is not None: # I believe this is required since the stop tokens # are not enforced by the model parameters text = enforce_stop_tokens(text, stop) return text