import json
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
class SVEndpointHandler:
    """SambaNova Systems interface for Sambaverse endpoints.

    :param str host_url: Base URL of the DaaS API service.
    """

    API_BASE_PATH = "/api/predict"

    def __init__(self, host_url: str):
        """Initialize the SVEndpointHandler.

        :param str host_url: Base URL of the DaaS API service.
        """
        self.host_url = host_url
        self.http_session = requests.Session()

    @staticmethod
    def _build_payload(
        input: Union[List[str], str], params: Optional[str]
    ) -> Dict[str, Any]:
        """Build the request payload shared by the streaming and
        non-streaming prediction calls.

        :param input: the user message content
        :param params: optional JSON string of tuning params
        :return: the request body to POST
        """
        parsed_input = json.dumps(
            {
                "conversation_id": "sambaverse-conversation-id",
                "messages": [
                    {
                        "message_id": 0,
                        "role": "user",
                        "content": input,
                    }
                ],
            }
        )
        if params:
            return {"instance": parsed_input, "params": json.loads(params)}
        return {"instance": parsed_input}

    @staticmethod
    def _process_response(response: requests.Response) -> Dict:
        """Process the API response and return a result dict.

        All result dicts, whether the call succeeded or failed, contain the
        ``status_code`` key with the API response status code. If the API
        returned an error, the result dict contains the ``detail`` key with
        the error message.

        :param requests.Response response: the response object to process
        :return: the response dict
        :rtype: dict
        """
        result: Dict[str, Any] = {}
        try:
            # The endpoint replies with one JSON document per line; the last
            # line carries either the final result or an error marker.
            lines_result = response.text.strip().split("\n")
            text_result = lines_result[-1]
            if response.status_code == 200 and json.loads(text_result).get("error"):
                # Error reported after a partial stream: rebuild the
                # completion from the streamed tokens and reuse the metadata
                # of the last good line.
                completion = ""
                for line in lines_result[:-1]:
                    completion += json.loads(line)["result"]["responses"][0][
                        "stream_token"
                    ]
                text_result = lines_result[-2]
                result = json.loads(text_result)
                result["result"]["responses"][0]["completion"] = completion
            else:
                result = json.loads(text_result)
        except Exception as e:
            result["detail"] = str(e)
        if "status_code" not in result:
            result["status_code"] = response.status_code
        return result

    @staticmethod
    def _process_streaming_response(
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """Process a streaming response, yielding one chunk dict per line."""
        try:
            for line in response.iter_lines():
                chunk = json.loads(line)
                if "status_code" not in chunk:
                    chunk["status_code"] = response.status_code
                if chunk["status_code"] == 200 and chunk.get("error"):
                    # Mid-stream error: emit a final empty-token chunk so the
                    # caller still receives the error payload, then stop.
                    # (Bug fix: the previous ``return chunk`` inside this
                    # generator silently discarded the chunk.)
                    chunk["result"] = {"responses": [{"stream_token": ""}]}
                    yield chunk
                    return
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Error processing streaming response: {e}")

    def _get_full_url(self) -> str:
        """Return the full API URL.

        :returns: the full API URL
        :rtype: str
        """
        return f"{self.host_url}{self.API_BASE_PATH}"

    def nlp_predict(
        self,
        key: str,
        sambaverse_model_name: Optional[str],
        input: Union[List[str], str],
        params: Optional[str] = "",
        stream: bool = False,
    ) -> Dict:
        """NLP predict using an inline input string.

        :param str key: API key
        :param str sambaverse_model_name: expert model name to use
        :param str input: input string
        :param str params: input params string (JSON)
        :param bool stream: unused here; kept for interface compatibility
        :returns: Prediction results
        :rtype: dict
        """
        data = self._build_payload(input, params)
        response = self.http_session.post(
            self._get_full_url(),
            headers={
                "key": key,
                "Content-Type": "application/json",
                "modelName": sambaverse_model_name,
            },
            json=data,
        )
        return SVEndpointHandler._process_response(response)

    def nlp_predict_stream(
        self,
        key: str,
        sambaverse_model_name: Optional[str],
        input: Union[List[str], str],
        params: Optional[str] = "",
    ) -> Iterator[Dict]:
        """NLP predict using an inline input string, streaming the response.

        :param str key: API key
        :param str sambaverse_model_name: expert model name to use
        :param str input: input string
        :param str params: input params string (JSON)
        :returns: iterator over prediction result chunks
        :rtype: Iterator[Dict]
        """
        data = self._build_payload(input, params)
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(),
            headers={
                "key": key,
                "Content-Type": "application/json",
                "modelName": sambaverse_model_name,
            },
            json=data,
            stream=True,
        )
        for chunk in SVEndpointHandler._process_streaming_response(response):
            yield chunk
class Sambaverse(LLM):
    """Sambaverse large language models.

    To use, you should have the environment variable ``SAMBAVERSE_API_KEY``
    set with your API key.

    Get one at https://sambaverse.sambanova.ai

    Read extra documentation at
    https://docs.sambanova.ai/sambaverse/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import Sambaverse

        Sambaverse(
            sambaverse_url="https://sambaverse.sambanova.ai",
            sambaverse_api_key="your-sambaverse-api-key",
            sambaverse_model_name="Meta/llama-2-7b-chat-hf",
            streaming=False,
            model_kwargs={
                "select_expert": "llama-2-7b-chat-hf",
                "do_sample": False,
                "max_tokens_to_generate": 100,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1.0,
                "top_k": 50,
            },
        )
    """

    sambaverse_url: str = ""
    """Sambaverse url to use"""

    sambaverse_api_key: str = ""
    """sambaverse api key"""

    sambaverse_model_name: Optional[str] = None
    """sambaverse expert model to use"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streaming responses."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True since this class is serializable by LangChain."""
        return True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        values["sambaverse_url"] = get_from_dict_or_env(
            values,
            "sambaverse_url",
            "SAMBAVERSE_URL",
            default="https://sambaverse.sambanova.ai",
        )
        values["sambaverse_api_key"] = get_from_dict_or_env(
            values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
        )
        values["sambaverse_model_name"] = get_from_dict_or_env(
            values, "sambaverse_model_name", "SAMBAVERSE_MODEL_NAME"
        )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_kwargs": self.model_kwargs}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Sambaverse LLM"

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings.

        Returns:
            The tuning parameters as a JSON string.
        """
        # Work on a copy so self.model_kwargs is never mutated: the previous
        # implementation rewrote "stop_sequences" in place and its "restore"
        # step could leave a spurious empty entry behind.
        _model_kwargs = dict(self.model_kwargs or {})
        _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
        _stop_sequences = stop or _kwarg_stop_sequences
        if not _kwarg_stop_sequences:
            _model_kwargs["stop_sequences"] = ",".join(
                f'"{x}"' for x in _stop_sequences
            )
        # Sambaverse expects each param encoded as {"type": ..., "value": ...}.
        tuning_params_dict = {
            k: {"type": type(v).__name__, "value": str(v)}
            for k, v in _model_kwargs.items()
        }
        return json.dumps(tuning_params_dict)

    def _handle_nlp_predict(
        self,
        sdk: SVEndpointHandler,
        prompt: Union[List[str], str],
        tuning_params: str,
    ) -> str:
        """Perform an NLP prediction using the Sambaverse endpoint handler.

        Args:
            sdk: The SVEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        response = sdk.nlp_predict(
            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
        )
        if response["status_code"] != 200:
            error = response.get("error")
            if error:
                optional_code = error.get("code")
                optional_details = error.get("details")
                optional_message = error.get("message")
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n"
                    f"Message: {optional_message}\n"
                    f"Details: {optional_details}\n"
                    f"Code: {optional_code}\n"
                )
            else:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}."
                    f"{response}."
                )
        return response["result"]["responses"][0]["completion"]

    def _handle_completion_requests(
        self, prompt: Union[List[str], str], stop: Optional[List[str]]
    ) -> str:
        """Perform a prediction using the Sambaverse endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SVEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """Perform a streaming request to the LLM.

        Args:
            sdk: The SVEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
        ):
            if chunk["status_code"] != 200:
                error = chunk.get("error")
                if error:
                    optional_code = error.get("code")
                    optional_details = error.get("details")
                    optional_message = error.get("message")
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                else:
                    raise RuntimeError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}."
                        f"{chunk}."
                    )
            text = chunk["result"]["responses"][0]["stream_token"]
            generated_chunk = GenerationChunk(text=text)
            yield generated_chunk

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream the Sambaverse's LLM on the given prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambaverse model in API call.

        Returns:
            An iterator of GenerationChunks.
        """
        ss_endpoint = SVEndpointHandler(self.sambaverse_url)
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(
                    ss_endpoint, prompt, tuning_params
                ):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                return
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """Perform a streaming request to the LLM and collect the full output.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambaverse model in API call.

        Returns:
            The model output as a string.
        """
        completion = ""
        for chunk in self._stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given input.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambaverse model in API call.

        Returns:
            The model output as a string.
        """
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e
class SSEndpointHandler:
    """SambaNova Systems interface for SambaStudio model endpoints.

    :param str host_url: Base URL of the DaaS API service.
    """

    def __init__(self, host_url: str, api_base_uri: str):
        """Initialize the SSEndpointHandler.

        :param str host_url: Base URL of the DaaS API service.
        :param str api_base_uri: Base URI of the DaaS API service.
        """
        self.host_url = host_url
        self.api_base_uri = api_base_uri
        self.http_session = requests.Session()

    def _process_response(self, response: requests.Response) -> Dict:
        """Process the API response and return a result dict.

        All result dicts, whether the call succeeded or failed, contain the
        ``status_code`` key with the API response status code. If the API
        returned an error, the result dict contains the ``detail`` key with
        the error message.

        :param requests.Response response: the response object to process
        :return: the response dict
        :rtype: dict
        """
        result: Dict[str, Any] = {}
        try:
            result = response.json()
        except Exception as e:
            result["detail"] = str(e)
        if "status_code" not in result:
            result["status_code"] = response.status_code
        return result

    def _process_streaming_response(
        self,
        response: requests.Response,
    ) -> Generator[Dict, None, None]:
        """Process the streaming response, yielding one chunk dict at a time.

        "nlp" endpoints stream Server-Sent Events; "generic" endpoints stream
        newline-delimited JSON.
        """
        if "nlp" in self.api_base_uri:
            try:
                import sseclient
            except ImportError:
                raise ImportError(
                    "could not import sseclient library"
                    "Please install it with `pip install sseclient-py`."
                )
            client = sseclient.SSEClient(response)
            close_conn = False
            for event in client.events():
                if event.event == "error_event":
                    # Yield the error event first, then close the connection.
                    close_conn = True
                chunk = {
                    "event": event.event,
                    "data": event.data,
                    "status_code": response.status_code,
                }
                yield chunk
                if close_conn:
                    client.close()
        elif "generic" in self.api_base_uri:
            try:
                for line in response.iter_lines():
                    chunk = json.loads(line)
                    if "status_code" not in chunk:
                        chunk["status_code"] = response.status_code
                    if chunk["status_code"] == 200 and chunk.get("error"):
                        # Normalize error chunks so downstream token
                        # extraction does not KeyError.
                        chunk["result"] = {"responses": [{"stream_token": ""}]}
                    yield chunk
            except Exception as e:
                raise RuntimeError(f"Error processing streaming response: {e}")
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.api_base_uri} not implemented"
            )

    def _get_full_url(self, path: str) -> str:
        """Return the full API URL for a given path.

        :param str path: the sub-path
        :returns: the full API URL for the sub-path
        :rtype: str
        """
        return f"{self.host_url}/{self.api_base_uri}/{path}"

    def nlp_predict(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = "",
        stream: bool = False,
    ) -> Dict:
        """NLP predict using an inline input string.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API key
        :param str input: input string
        :param str params: input params string (JSON)
        :param bool stream: unused here; kept for interface compatibility
        :returns: Prediction results
        :rtype: dict
        """
        if isinstance(input, str):
            input = [input]
        # "nlp" endpoints expect "inputs"; "generic" endpoints expect
        # "instances".
        if "nlp" in self.api_base_uri:
            if params:
                data = {"inputs": input, "params": json.loads(params)}
            else:
                data = {"inputs": input}
        elif "generic" in self.api_base_uri:
            if params:
                data = {"instances": input, "params": json.loads(params)}
            else:
                data = {"instances": input}
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.api_base_uri} not implemented"
            )
        response = self.http_session.post(
            self._get_full_url(f"{project}/{endpoint}"),
            headers={"key": key},
            json=data,
        )
        return self._process_response(response)

    def nlp_predict_stream(
        self,
        project: str,
        endpoint: str,
        key: str,
        input: Union[List[str], str],
        params: Optional[str] = "",
    ) -> Iterator[Dict]:
        """NLP predict using an inline input string, streaming the response.

        :param str project: Project ID in which the endpoint exists
        :param str endpoint: Endpoint ID
        :param str key: API key
        :param str input: input string
        :param str params: input params string (JSON)
        :returns: iterator over prediction result chunks
        :rtype: Iterator[Dict]
        """
        # "nlp" streaming takes a batch ("inputs"); "generic" streaming takes
        # a single instance ("instance").
        if "nlp" in self.api_base_uri:
            if isinstance(input, str):
                input = [input]
            if params:
                data = {"inputs": input, "params": json.loads(params)}
            else:
                data = {"inputs": input}
        elif "generic" in self.api_base_uri:
            if isinstance(input, list):
                input = input[0]
            if params:
                data = {"instance": input, "params": json.loads(params)}
            else:
                data = {"instance": input}
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.api_base_uri} not implemented"
            )
        # Streaming output
        response = self.http_session.post(
            self._get_full_url(f"stream/{project}/{endpoint}"),
            headers={"key": key},
            json=data,
            stream=True,
        )
        for chunk in self._process_streaming_response(response):
            yield chunk
class SambaStudio(LLM):
    """SambaStudio large language models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
    ``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
    ``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
    ``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
    ``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.

    https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite

    Read extra documentation at
    https://docs.sambanova.ai/sambastudio/latest/index.html

    Example:
    .. code-block:: python

        from langchain_community.llms.sambanova import SambaStudio

        SambaStudio(
            sambastudio_base_url="your-SambaStudio-environment-URL",
            sambastudio_base_uri="your-SambaStudio-base-URI",
            sambastudio_project_id="your-SambaStudio-project-ID",
            sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
            sambastudio_api_key="your-SambaStudio-endpoint-API-key",
            streaming=False,
            model_kwargs={
                "do_sample": False,
                "max_tokens_to_generate": 1000,
                "temperature": 0.7,
                "top_p": 1.0,
                "repetition_penalty": 1,
                "top_k": 50,
            },
        )
    """

    sambastudio_base_url: str = ""
    """Base url to use"""

    sambastudio_base_uri: str = ""
    """endpoint base uri"""

    sambastudio_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_api_key: str = ""
    """sambastudio api key"""

    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    streaming: Optional[bool] = False
    """Streaming flag to get streaming responses."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True since this class is serializable by LangChain."""
        return True

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_kwargs": self.model_kwargs}}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Sambastudio LLM"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        values["sambastudio_base_url"] = get_from_dict_or_env(
            values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
        )
        values["sambastudio_base_uri"] = get_from_dict_or_env(
            values,
            "sambastudio_base_uri",
            "SAMBASTUDIO_BASE_URI",
            default="api/predict/nlp",
        )
        values["sambastudio_project_id"] = get_from_dict_or_env(
            values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
        )
        values["sambastudio_endpoint_id"] = get_from_dict_or_env(
            values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
        )
        values["sambastudio_api_key"] = get_from_dict_or_env(
            values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
        )
        return values

    def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
        """Get the tuning parameters to use when calling the LLM.

        Args:
            stop: Stop words to use when generating. NOTE: stop sequences are
                currently not forwarded to SambaStudio (``_call`` raises when
                ``stop`` is provided).

        Returns:
            The tuning parameters as a JSON string.
        """
        _model_kwargs = self.model_kwargs or {}
        # SambaStudio expects each param encoded as {"type": ..., "value": ...}.
        tuning_params_dict = {
            k: {"type": type(v).__name__, "value": str(v)}
            for k, v in _model_kwargs.items()
        }
        return json.dumps(tuning_params_dict)

    def _handle_nlp_predict(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> str:
        """Perform an NLP prediction using the SambaStudio endpoint handler.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        response = sdk.nlp_predict(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        )
        if response["status_code"] != 200:
            optional_detail = response.get("detail")
            if optional_detail:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n Details: {optional_detail}"
                )
            else:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{response['status_code']}.\n response {response}"
                )
        # Response shape differs between "nlp" and "generic" endpoints.
        if "nlp" in self.sambastudio_base_uri:
            return response["data"][0]["completion"]
        elif "generic" in self.sambastudio_base_uri:
            return response["predictions"][0]["completion"]
        else:
            raise ValueError(
                f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented"
            )

    def _handle_completion_requests(
        self, prompt: Union[List[str], str], stop: Optional[List[str]]
    ) -> str:
        """Perform a prediction using the SambaStudio endpoint handler.

        Args:
            prompt: The prompt to use for the prediction.
            stop: stop sequences.

        Returns:
            The prediction result.

        Raises:
            RuntimeError: If the prediction fails.
        """
        ss_endpoint = SSEndpointHandler(
            self.sambastudio_base_url, self.sambastudio_base_uri
        )
        tuning_params = self._get_tuning_params(stop)
        return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

    def _handle_nlp_predict_stream(
        self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
    ) -> Iterator[GenerationChunk]:
        """Perform a streaming request to the LLM.

        Args:
            sdk: The SSEndpointHandler to use for the prediction.
            prompt: The prompt to use for the prediction.
            tuning_params: The tuning parameters to use for the prediction.

        Returns:
            An iterator of GenerationChunks.
        """
        for chunk in sdk.nlp_predict_stream(
            self.sambastudio_project_id,
            self.sambastudio_endpoint_id,
            self.sambastudio_api_key,
            prompt,
            tuning_params,
        ):
            if chunk["status_code"] != 200:
                error = chunk.get("error")
                if error:
                    optional_code = error.get("code")
                    optional_details = error.get("details")
                    optional_message = error.get("message")
                    raise ValueError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}.\n"
                        f"Message: {optional_message}\n"
                        f"Details: {optional_details}\n"
                        f"Code: {optional_code}\n"
                    )
                else:
                    raise RuntimeError(
                        f"Sambanova /complete call failed with status code "
                        f"{chunk['status_code']}."
                        f"{chunk}."
                    )
            # Chunk shape differs between "nlp" and "generic" endpoints.
            if "nlp" in self.sambastudio_base_uri:
                text = json.loads(chunk["data"])["stream_token"]
            elif "generic" in self.sambastudio_base_uri:
                text = chunk["result"]["responses"][0]["stream_token"]
            else:
                raise ValueError(
                    f"handling of endpoint uri: {self.sambastudio_base_uri}"
                    f"not implemented"
                )
            generated_chunk = GenerationChunk(text=text)
            yield generated_chunk

    def _stream(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Sambanova's complete endpoint, streaming the response.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambastudio model in API call.

        Returns:
            An iterator of GenerationChunks.
        """
        ss_endpoint = SSEndpointHandler(
            self.sambastudio_base_url, self.sambastudio_base_uri
        )
        tuning_params = self._get_tuning_params(stop)
        try:
            if self.streaming:
                for chunk in self._handle_nlp_predict_stream(
                    ss_endpoint, prompt, tuning_params
                ):
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text)
                    yield chunk
            else:
                return
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

    def _handle_stream_request(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        kwargs: Dict[str, Any],
    ) -> str:
        """Perform a streaming request to the LLM and collect the full output.

        Args:
            prompt: The prompt to generate from.
            stop: Stop words to use when generating. Model output is cut off
                at the first occurrence of any of the stop substrings.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambastudio model in API call.

        Returns:
            The model output as a string.
        """
        completion = ""
        for chunk in self._stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            completion += chunk.text
        return completion

    def _call(
        self,
        prompt: Union[List[str], str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Sambanova's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager for the run.
            **kwargs: Additional keyword arguments. directly passed
                to the sambastudio model in API call.

        Returns:
            The string generated by the model.

        Raises:
            Exception: stop sequences are not yet supported for SambaStudio.
        """
        if stop is not None:
            raise Exception("stop not implemented")
        try:
            if self.streaming:
                return self._handle_stream_request(prompt, stop, run_manager, kwargs)
            return self._handle_completion_requests(prompt, stop)
        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e