"""Wrapper around Together AI's Completion API."""
import logging
from typing import Any, Dict, List, Optional
from aiohttp import ClientSession
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.utilities.requests import Requests
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@deprecated(
    since="0.0.12", removal="0.3", alternative_import="langchain_together.Together"
)
class Together(LLM):
    """LLM models from `Together`.

    To use, you need an API key, which you can find here:
    https://api.together.xyz/settings/api-keys. It can be passed in as the init
    param ``together_api_key`` or set as the environment variable
    ``TOGETHER_API_KEY``.

    Together AI API reference: https://docs.together.ai/reference/inference
    """

    base_url: str = "https://api.together.xyz/inference"
    """Base inference API URL."""
    together_api_key: SecretStr
    """Together AI API key. Get it here: https://api.together.xyz/settings/api-keys"""
    model: str
    """Model name. Available models are listed here:
    https://docs.together.ai/docs/inference-models
    """
    temperature: Optional[float] = None
    """Model temperature."""
    top_p: Optional[float] = None
    """Nucleus-sampling threshold: limits the candidate tokens at each step based
    on their cumulative probability. Omitted from the request when None."""
    top_k: Optional[int] = None
    """Limits the number of candidate tokens considered at each step to the most
    likely ``top_k`` ones. Omitted from the request when None."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    repetition_penalty: Optional[float] = None
    """Reduces the likelihood of repeated sequences; higher values reduce
    repetition. Omitted from the request when None."""
    logprobs: Optional[int] = None
    """How many top token log probabilities to include in the response for each
    token generation step. Omitted from the request when None."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that an API key is available.

        Reads ``together_api_key`` from the init values, falling back to the
        ``TOGETHER_API_KEY`` environment variable, and stores it as a SecretStr.
        """
        values["together_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "together"

    def _format_output(self, output: dict) -> str:
        """Return the text of the first choice in a response body."""
        return output["output"]["choices"][0]["text"]

    @staticmethod
    def get_user_agent() -> str:
        """Return a ``langchain/<version>`` user-agent string."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @property
    def default_params(self) -> Dict[str, Any]:
        """Generation parameters included in every request.

        Entries whose value is None are dropped before the request is sent.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "repetition_penalty": self.repetition_penalty,
            # Previously declared but never sent; None is filtered out, so the
            # default payload is unchanged.
            "logprobs": self.logprobs,
        }

    def _build_headers(self) -> Dict[str, str]:
        """Build the bearer-auth JSON headers for an inference request."""
        return {
            "Authorization": f"Bearer {self.together_api_key.get_secret_value()}",
            "Content-Type": "application/json",
        }

    def _build_payload(
        self,
        prompt: str,
        stop: Optional[List[str]],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge default params, prompt, stop and call-time kwargs into a payload.

        Call-time ``extra`` entries override the defaults. None values are
        removed so unset options are not sent over HTTP.
        """
        # A single stop sequence is sent as a plain string rather than a list.
        stop_to_use = stop[0] if stop and len(stop) == 1 else stop
        payload: Dict[str, Any] = {
            **self.default_params,
            "prompt": prompt,
            "stop": stop_to_use,
            **extra,
        }
        # filter None values to not pass them to the http payload
        return {k: v for k, v in payload.items() if v is not None}

    @staticmethod
    def _raise_for_status(status: int, body: str) -> None:
        """Raise the error matching a non-200 HTTP status; no-op for 200.

        Raises:
            Exception: For 5xx statuses and any other non-200 status.
            ValueError: For 4xx statuses.
        """
        if status >= 500:
            raise Exception(f"Together Server: Error {status}")
        if status >= 400:
            raise ValueError(f"Together received an invalid payload: {body}")
        if status != 200:
            raise Exception(
                f"Together returned an unexpected response with status "
                f"{status}: {body}"
            )

    def _parse_response(self, data: Dict[str, Any]) -> str:
        """Return the generated text, or raise if the job did not finish.

        Raises:
            Exception: When ``data["status"]`` is not ``"finished"``; the message
                is ``data["error"]`` or ``"Undefined Error"``.
        """
        if data.get("status") != "finished":
            err_msg = data.get("error", "Undefined Error")
            raise Exception(err_msg)
        return self._format_output(data)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Together's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences; a single one is sent as a string.
            run_manager: Callback manager (not used by this implementation).
            **kwargs: Extra request parameters; they override the defaults.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the API responds with a 4xx status.
            Exception: On a 5xx or other non-200 status, or when the response's
                ``status`` field is not ``"finished"``.
        """
        request = Requests(headers=self._build_headers())
        # Requests.post sends ``data`` as the JSON request body.
        response = request.post(
            url=self.base_url, data=self._build_payload(prompt, stop, kwargs)
        )
        if response.status_code != 200:
            self._raise_for_status(response.status_code, response.text)
        return self._parse_response(response.json())

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously call Together's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences; a single one is sent as a string.
            run_manager: Async callback manager (not used by this implementation).
            **kwargs: Extra request parameters; they override the defaults.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the API responds with a 4xx status.
            Exception: On a 5xx or other non-200 status, or when the response's
                ``status`` field is not ``"finished"``.
        """
        async with ClientSession() as session:
            async with session.post(
                self.base_url,
                json=self._build_payload(prompt, stop, kwargs),
                headers=self._build_headers(),
            ) as response:
                if response.status != 200:
                    # ClientResponse.text is a coroutine method: it must be awaited
                    # to get the body; errors="replace" keeps a malformed body from
                    # masking the real HTTP error.
                    body = await response.text(errors="replace")
                    self._raise_for_status(response.status, body)
                response_json = await response.json()
        return self._parse_response(response_json)