# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
"""LLM for OCI data science model deployment endpoint."""
import json
import logging
import traceback
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Union,
)
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.utils import get_from_dict_or_env
from pydantic import Field, model_validator
from langchain_community.utilities.requests import Requests
logger = logging.getLogger(__name__)
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"
DEFAULT_TIME_OUT = 300
DEFAULT_CONTENT_TYPE_JSON = "application/json"
DEFAULT_MODEL_NAME = "odsc-llm"
[docs]
class TokenExpiredError(Exception):
    """Raised when a request is rejected because the auth token has expired."""
[docs]
class ServerError(Exception):
    """Raised when the inference endpoint answers with a server-side error."""
def _create_retry_decorator(
    llm: "BaseOCIModelDeployment",
    *,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator applied to endpoint calls.

    Args:
        llm: Model whose ``max_retries`` setting is passed on as the retry limit.
        run_manager: Optional sync or async callback manager, forwarded
            unchanged to ``create_base_retry_decorator``.

    Returns:
        A decorator configured to retry on ``requests`` connect timeouts and
        on ``TokenExpiredError``.
    """
    retryable_errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )
[docs]
class BaseOCIModelDeployment(Serializable):
"""Base class for LLM deployed on OCI Data Science Model Deployment.

Holds the authentication, endpoint, streaming, retry and header settings
and the request/response helpers shared by the concrete LLM classes.
"""
auth: dict = Field(default_factory=dict, exclude=True)
"""ADS auth dictionary for OCI authentication:
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
This can be generated by calling `ads.common.auth.api_keys()`
or `ads.common.auth.resource_principal()`. If this is not
provided then `ads.common.auth.default_signer()` will be used.
The request signer is read from the `"signer"` key of this dictionary."""
endpoint: str = ""
"""The uri of the endpoint from the deployed Model Deployment model.
If not passed explicitly it is read from the `OCI_LLM_ENDPOINT`
environment variable."""
streaming: bool = False
"""Whether to stream the results or not."""
max_retries: int = 3
"""Maximum number of retries to make when generating."""
default_headers: Optional[Dict[str, Any]] = None
"""Extra headers to be added to every Model Deployment request."""
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Dict:
    """Ensure `oracle-ads` is importable and resolve auth and endpoint.

    Args:
        values: Raw field values supplied to the model.

    Returns:
        The same dictionary, with ``auth`` set to the ADS default signer when
        it was missing or empty, and ``endpoint`` taken from the values or
        from the ``OCI_LLM_ENDPOINT`` environment variable.

    Raises:
        ImportError: If the ``ads`` package cannot be imported.
    """
    try:
        import ads
    except ImportError as ex:
        raise ImportError(
            "Could not import ads python package. "
            "Please install it with `pip install oracle_ads`."
        ) from ex

    auth = values.get("auth", None)
    if not auth:
        values["auth"] = ads.common.auth.default_signer()

    endpoint = get_from_dict_or_env(values, "endpoint", "OCI_LLM_ENDPOINT")
    values["endpoint"] = endpoint
    return values
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Construct and return the headers for a request.

    Args:
        is_async (bool, optional): Indicates if the request is asynchronous.
            Defaults to `False`. Asynchronous requests are signed here and
            the signed request's headers are merged into the result.
        body (optional): The request body used to build the request that is
            signed when `is_async` is true.

    Returns:
        Dict: A new dictionary containing the headers for the request. The
        instance's `default_headers` are never modified.
    """
    # Work on a copy: updating `self.default_headers` in place would leak
    # per-request values (request signature, streaming flags) into every
    # later call made with this instance.
    headers = dict(self.default_headers or {})
    streaming_headers = {"enable-streaming": "true", "Accept": "text/event-stream"}

    if is_async:
        signer = self.auth["signer"]
        prepared = requests.Request("POST", self.endpoint, json=body).prepare()
        signed = signer(prepared)
        for key, value in signed.headers.items():
            headers[key] = value
        if self.streaming:
            headers.update(streaming_headers)
        return headers

    headers["Content-Type"] = DEFAULT_CONTENT_TYPE_JSON
    if self.streaming:
        headers.update(streaming_headers)
    return headers
[docs]
def completion_with_retry(
    self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
    """Post the payload to the endpoint, retrying via the retry decorator.

    Args:
        run_manager: Optional callback manager handed to the retry decorator.
        **kwargs: Must contain ``data`` (the request payload). ``request_timeout``
            and ``stream`` are optional; every remaining keyword is forwarded
            to the underlying ``post`` call.

    Returns:
        The response object returned by ``Requests.post`` after its status
        has been checked.

    Raises:
        TokenExpiredError: Propagated unchanged so the call can be retried.
        RuntimeError: For any other failure, chained to the original error.
    """
    retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        # Bind everything the error handler reports *before* entering the
        # `try`; otherwise an early failure (e.g. a missing `data` keyword)
        # would be masked by an UnboundLocalError raised while logging.
        request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
        stream = kwargs.pop("stream", self.streaming)
        data = None
        try:
            data = kwargs.pop("data")
            request = Requests(
                headers=self._headers(), auth=self.auth.get("signer")
            )
            response = request.post(
                url=self.endpoint,
                data=data,
                timeout=request_timeout,
                stream=stream,
                **kwargs,
            )
            self._check_response(response)
            return response
        except TokenExpiredError:
            # Re-raise untouched so the retry decorator sees the original type.
            raise
        except Exception as err:
            # NOTE(review): connect timeouts are also wrapped in RuntimeError
            # here, so the retry decorator never sees them even though it
            # lists ConnectTimeout as retryable. Confirm whether intended.
            traceback.print_exc()
            logger.debug(
                f"Requests payload: {data}. Requests arguments: "
                f"url={self.endpoint},timeout={request_timeout},stream={stream}. "
                f"Additional request kwargs={kwargs}."
            )
            raise RuntimeError(
                f"Error occurs by inference endpoint: {str(err)}"
            ) from err

    return _completion_with_retry(**kwargs)
[docs]
async def acompletion_with_retry(
    self,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Asynchronously post the payload, retrying via the retry decorator.

    Args:
        run_manager: Optional async callback manager handed to the retry
            decorator.
        **kwargs: Must contain ``data`` (the request payload).
            ``request_timeout`` and ``stream`` are optional.

    Returns:
        When streaming, the async iterator produced by ``_aiter_sse``;
        otherwise the decoded JSON body of the response.

    Raises:
        TokenExpiredError: Propagated unchanged so the call can be retried.
        RuntimeError: For any other failure, chained to the original error.
    """
    retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Bind everything the error handler reports *before* entering the
        # `try`; otherwise an early failure (e.g. a missing `data` keyword)
        # would be masked by an UnboundLocalError raised while logging.
        request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT)
        stream = kwargs.pop("stream", self.streaming)
        data = None
        try:
            data = kwargs.pop("data")
            request = Requests(headers=self._headers(is_async=True, body=data))
            if stream:
                # The request is only sent once the returned iterator is
                # consumed, so errors raised then happen outside this retry.
                response = request.apost(
                    url=self.endpoint,
                    data=data,
                    timeout=request_timeout,
                )
                return self._aiter_sse(response)

            async with request.apost(
                url=self.endpoint,
                data=data,
                timeout=request_timeout,
            ) as resp:
                self._check_response(resp)
                # Use a separate name so the error log below keeps showing
                # the request payload rather than the response body.
                result = await resp.json()
                return result
        except TokenExpiredError:
            # Re-raise untouched so the retry decorator sees the original type.
            raise
        except Exception as err:
            traceback.print_exc()
            logger.debug(
                f"Requests payload: `{data}`. "
                f"Stream mode={stream}. "
                f"Requests kwargs: url={self.endpoint}, timeout={request_timeout}."
            )
            raise RuntimeError(
                f"Error occurs by inference endpoint: {str(err)}"
            ) from err

    return await _completion_with_retry(**kwargs)
def _check_response(self, response: Any) -> None:
"""Handle server error by checking the response status.
Args:
response:
The response object from either `requests` or `aiohttp` library.
Raises:
TokenExpiredError:
If the response status code is 401 and the token refresh is successful.
ServerError:
If any other HTTP error occurs.
"""
try:
response.raise_for_status()
except requests.exceptions.HTTPError as http_err:
status_code = (
response.status_code
if hasattr(response, "status_code")
else response.status
)
if status_code == 401 and self._refresh_signer():
raise TokenExpiredError() from http_err
raise ServerError(
f"Server error: {str(http_err)}. \nMessage: {response.text}"
) from http_err
def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]:
"""Parse a stream of byte lines and yield parsed string lines.
Args:
lines (Iterator[bytes]):
An iterator that yields lines in byte format.
Yields:
Iterator[str]:
An iterator that yields parsed lines as strings.
"""
for line in lines:
_line = self._parse_stream_line(line)
if _line is not None:
yield _line
async def _parse_stream_async(
    self,
    lines: aiohttp.StreamReader,
) -> AsyncIterator[str]:
    """Asynchronously turn raw byte lines into parsed string lines.

    Args:
        lines (aiohttp.StreamReader):
            An `aiohttp.StreamReader` object that yields lines in byte format.

    Yields:
        AsyncIterator[str]:
            Every line for which `_parse_stream_line` returns a value;
            lines that parse to `None` are dropped.
    """
    async for raw_line in lines:
        parsed = self._parse_stream_line(raw_line)
        if parsed is None:
            continue
        yield parsed
def _parse_stream_line(self, line: bytes) -> Optional[str]:
"""Parse a single byte line and return a processed string line if valid.
Args:
line (bytes): A single line in byte format.
Returns:
Optional[str]:
The processed line as a string if valid, otherwise `None`.
"""
line = line.strip()
if not line:
return None
_line = line.decode("utf-8")
if _line.lower().startswith("data:"):
_line = _line[5:].lstrip()
if _line.startswith("[DONE]"):
return None
return _line
return None
async def _aiter_sse(
self,
async_cntx_mgr: Any,
) -> AsyncIterator[str]:
"""Asynchronously iterate over server-sent events (SSE).
Args:
async_cntx_mgr: An asynchronous context manager that yields a client
response object.
Yields:
AsyncIterator[str]: An asynchronous iterator that yields parsed server-sent
event lines as json string.
"""
async with async_cntx_mgr as client_resp:
self._check_response(client_resp)
async for line in self._parse_stream_async(client_resp.content):
yield line
def _refresh_signer(self) -> bool:
"""Attempt to refresh the security token using the signer.
Returns:
bool: `True` if the token was successfully refreshed, `False` otherwise.
"""
if self.auth.get("signer", None) and hasattr(
self.auth["signer"], "refresh_security_token"
):
self.auth["signer"].refresh_security_token()
return True
return False
@classmethod
def is_lc_serializable(cls) -> bool:
    """Tell LangChain that this model can be serialized (always `True`)."""
    return True
[docs]
class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment):
"""LLM deployed on OCI Data Science Model Deployment.

To use, you must provide the model HTTP endpoint from your deployed
model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.

To authenticate, `oracle-ads` is used to automatically load
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

Make sure to have the required policies to access the OCI Data
Science Model Deployment endpoint. See:
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

Example:

    .. code-block:: python

        from langchain_community.llms import OCIModelDeploymentLLM

        llm = OCIModelDeploymentLLM(
            endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
            model="odsc-llm",
            streaming=True,
            model_kwargs={"frequency_penalty": 1.0},
            default_headers={
                "route": "/v1/completions",
                # other request headers ...
            }
        )
        llm.invoke("tell me a joke.")

Customized Usage:

    Users can inherit from this class and override `_process_response`,
    `_process_stream_response` and `_construct_json_body` to satisfy
    custom requirements.

    .. code-block:: python

        from langchain_community.llms import OCIModelDeploymentLLM

        class MyCustomizedModel(OCIModelDeploymentLLM):
            def _process_stream_response(self, response_json: dict) -> GenerationChunk:
                print("My customized output stream handler.")
                return GenerationChunk()

            def _process_response(self, response_json: dict) -> List[Generation]:
                print("My customized output handler.")
                return [Generation()]

            def _construct_json_body(self, prompt: str, param: dict) -> dict:
                print("My customized input handler.")
                return {}

        llm = MyCustomizedModel(
            endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict",
            model="<model_name>",
        )
        llm.invoke("tell me a joke.")
""" # noqa: E501
model: str = DEFAULT_MODEL_NAME
"""The name of the model."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: float = 0.2
"""A non-negative float that tunes the degree of randomness in generation."""
k: int = 50
"""Number of most likely tokens to consider at each step.
Sent to the endpoint as `top_k`."""
p: float = 0.75
"""Total probability mass of tokens to consider at each step.
Sent to the endpoint as `top_p`."""
best_of: int = 1
"""Generates best_of completions server-side and returns the "best"
(the one with the highest log probability per token).
"""
stop: Optional[List[str]] = None
"""Stop words to use when generating. Model output is cut off
at the first occurrence of any of these substrings."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model. They are merged over the
default parameters when the request payload is built."""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_model_deployment_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters."""
return {
"best_of": self.best_of,
"max_tokens": self.max_tokens,
"model": self.model,
"stop": self.stop,
"stream": self.streaming,
"temperature": self.temperature,
"top_k": self.k,
"top_p": self.p,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint": self.endpoint, "model_kwargs": _model_kwargs},
**self._default_params,
}
def _headers(
    self, is_async: Optional[bool] = False, body: Optional[dict] = None
) -> Dict:
    """Construct the request headers, adding a default ``route`` header.

    Args:
        is_async (bool, optional): Indicates if the request is asynchronous.
            Defaults to `False`.
        body (optional): The request body, forwarded to the parent
            implementation.

    Returns:
        Dict: The parent class's headers plus ``route``. A ``route`` value
        coming from the parent's headers replaces the default one.
    """
    headers = {"route": DEFAULT_INFERENCE_ENDPOINT}
    headers.update(super()._headers(is_async=is_async, body=body))
    return headers
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to OCI Data Science Model Deployment endpoint for each prompt.

    Args:
        prompts: The prompts to pass into the service.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional callback manager for the run.
        **kwargs: Extra invocation parameters merged into the request payload.

    Returns:
        The full LLM output, one list of generations per prompt.

    Example:

        .. code-block:: python

            response = llm.invoke("Tell me a joke.")
            response = llm.generate(["Tell me a joke."])
    """
    params = self._invocation_params(stop, **kwargs)
    results: List[List[Generation]] = []

    for prompt in prompts:
        body = self._construct_json_body(prompt, params)

        if not self.streaming:
            # NOTE(review): `kwargs` was already merged into `body` above and
            # is forwarded to the transport here as well; `_stream` uses a
            # separate `requests_kwargs` entry instead. Confirm intent.
            response = self.completion_with_retry(
                data=body,
                run_manager=run_manager,
                **kwargs,
            )
            results.append(self._process_response(response.json()))
            continue

        # Streaming: fold every streamed chunk into a single generation.
        merged = GenerationChunk(text="")
        for piece in self._stream(
            prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            merged += piece
        results.append([merged])

    return LLMResult(generations=results)
async def _agenerate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Call out to OCI Data Science Model Deployment endpoint async for each prompt.

    Args:
        prompts: The prompts to pass into the service.
        stop: Optional list of stop words to use when generating.
        run_manager: Optional async callback manager for the run.
        **kwargs: Extra invocation parameters merged into the request payload.

    Returns:
        The full LLM output, one list of generations per prompt.

    Example:

        .. code-block:: python

            response = await llm.ainvoke("Tell me a joke.")
            response = await llm.agenerate(["Tell me a joke."])
    """  # noqa: E501
    params = self._invocation_params(stop, **kwargs)
    results: List[List[Generation]] = []

    for prompt in prompts:
        body = self._construct_json_body(prompt, params)

        if not self.streaming:
            # NOTE(review): `kwargs` was already merged into `body` above and
            # is forwarded to the transport call here as well. Confirm intent.
            payload = await self.acompletion_with_retry(
                data=body,
                run_manager=run_manager,
                **kwargs,
            )
            results.append(self._process_response(payload))
            continue

        # Streaming: fold every streamed chunk into a single generation.
        merged = GenerationChunk(text="")
        async for piece in self._astream(
            prompt, stop=stop, run_manager=run_manager, **kwargs
        ):
            merged += piece
        results.append([merged])

    return LLMResult(generations=results)
def _stream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[GenerationChunk]:
    """Stream OCI Data Science Model Deployment endpoint on given prompt.

    Args:
        prompt (str):
            The prompt to pass into the model.
        stop (List[str], Optional):
            List of stop words to use when generating.
        run_manager:
            Optional callback manager, notified of every new chunk.
        kwargs:
            requests_kwargs:
                Additional ``**kwargs`` to pass to requests.post
            Every other keyword is merged into the request payload.

    Returns:
        An iterator of GenerationChunks.

    Example:

        .. code-block:: python

            response = llm.stream("Tell me a joke.")
    """
    requests_kwargs = kwargs.pop("requests_kwargs", {})
    # NOTE(review): this switches the instance itself into streaming mode and
    # never switches it back. Confirm the lasting change is intended.
    self.streaming = True
    body = self._construct_json_body(
        prompt, self._invocation_params(stop, **kwargs)
    )
    response = self.completion_with_retry(
        data=body, run_manager=run_manager, stream=True, **requests_kwargs
    )

    for sse_payload in self._parse_stream(response.iter_lines()):
        chunk = self._handle_sse_line(sse_payload)
        if run_manager:
            run_manager.on_llm_new_token(chunk.text, chunk=chunk)
        yield chunk
async def _astream(
    self,
    prompt: str,
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
    """Stream OCI Data Science Model Deployment endpoint async on given prompt.

    Args:
        prompt (str):
            The prompt to pass into the model.
        stop (List[str], Optional):
            List of stop words to use when generating.
        run_manager:
            Optional async callback manager, notified of every new chunk.
        kwargs:
            requests_kwargs:
                Additional ``**kwargs`` forwarded to the async completion call.
            Every other keyword is merged into the request payload.

    Returns:
        An async iterator of GenerationChunks.

    Example:

        .. code-block:: python

            async for chunk in llm.astream("Tell me a joke."):
                print(chunk, end="", flush=True)
    """
    requests_kwargs = kwargs.pop("requests_kwargs", {})
    # NOTE(review): this switches the instance itself into streaming mode and
    # never switches it back. Confirm the lasting change is intended.
    self.streaming = True
    body = self._construct_json_body(
        prompt, self._invocation_params(stop, **kwargs)
    )
    sse_payloads = await self.acompletion_with_retry(
        data=body, run_manager=run_manager, stream=True, **requests_kwargs
    )

    async for sse_payload in sse_payloads:
        chunk = self._handle_sse_line(sse_payload)
        if run_manager:
            await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
        yield chunk
def _construct_json_body(self, prompt: str, params: dict) -> dict:
"""Constructs the request body as a dictionary (JSON)."""
return {
"prompt": prompt,
**params,
}
def _invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> dict:
"""Combines the invocation parameters with default parameters."""
params = self._default_params
_model_kwargs = self.model_kwargs or {}
params["stop"] = stop or params.get("stop", [])
return {**params, **_model_kwargs, **kwargs}
def _process_stream_response(self, response_json: dict) -> GenerationChunk:
    """Turn one OpenAI-spec streaming payload into a GenerationChunk.

    Args:
        response_json: Decoded JSON of a single streamed event.

    Returns:
        A chunk holding the ``text`` of the first choice (empty string when
        the choice has no ``text`` key).

    Raises:
        ValueError: If ``choices`` is missing or empty, or its first entry
            is not a dictionary.
    """
    try:
        first_choice = response_json["choices"][0]
        if not isinstance(first_choice, dict):
            raise TypeError("Endpoint response is not well formed.")
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError("Error while formatting response payload.") from e

    chunk_text = first_choice.get("text", "")
    return GenerationChunk(text=chunk_text)
def _process_response(self, response_json: dict) -> List[Generation]:
    """Format an OpenAI-spec completion response.

    Args:
        response_json (dict): The JSON response from the model endpoint.

    Returns:
        List[Generation]: One `Generation` per entry in ``choices``, carrying
        the choice's ``text`` and the info extracted by `_generate_info`.

    Raises:
        ValueError: If the response JSON has no ``choices`` key or its value
            is not a list.
    """
    try:
        choices = response_json["choices"]
        if not isinstance(choices, list):
            raise TypeError("Endpoint response is not well formed.")
    except (KeyError, TypeError) as e:
        raise ValueError("Error while formatting response payload.") from e

    return [
        Generation(
            text=choice.get("text"),
            generation_info=self._generate_info(choice),
        )
        for choice in choices
    ]
def _generate_info(self, choice: dict) -> Any:
"""Extracts generation info from the response."""
gen_info = {}
finish_reason = choice.get("finish_reason", None)
logprobs = choice.get("logprobs", None)
index = choice.get("index", None)
if finish_reason:
gen_info.update({"finish_reason": finish_reason})
if logprobs is not None:
gen_info.update({"logprobs": logprobs})
if index is not None:
gen_info.update({"index": index})
return gen_info or None
def _handle_sse_line(self, line: str) -> GenerationChunk:
    """Convert one SSE data payload into a GenerationChunk.

    Args:
        line: The JSON text of a single server-sent event.

    Returns:
        The chunk built by `_process_stream_response`, or an empty chunk if
        the payload cannot be decoded or processed.
    """
    try:
        payload = json.loads(line)
        return self._process_stream_response(payload)
    except Exception:
        # Best effort: one bad event must not abort the whole stream. Record
        # it, though, so dropped or malformed payloads can be diagnosed
        # instead of silently turning into empty chunks.
        logger.debug(
            "Ignoring SSE payload that could not be processed: %r",
            line,
            exc_info=True,
        )
        return GenerationChunk(text="")
[docs]
class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
"""OCI Data Science Model Deployment TGI Endpoint.

To use, you must provide the model HTTP endpoint from your deployed
model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.

To authenticate, `oracle-ads` is used to automatically load
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

Make sure to have the required policies to access the OCI Data
Science Model Deployment endpoint. See:
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

Example:

    .. code-block:: python

        from langchain_community.llms import OCIModelDeploymentTGI

        llm = OCIModelDeploymentTGI(
            endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
            api="/v1/completions",
            streaming=True,
            temperature=0.2,
            seed=42,
            # other model parameters ...
        )
"""
# NOTE(review): `api` selects the payload/response format in this class, but
# the inherited `_headers` still sends the default `route` header; confirm
# whether `/generate` users must set `route` through `default_headers`.
api: Literal["/generate", "/v1/completions"] = "/v1/completions"
"""Api spec. Selects which request payload and response format is used."""
frequency_penalty: float = 0.0
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
seed: Optional[int] = None
"""Random sampling seed"""
repetition_penalty: Optional[float] = None
"""The parameter for repetition penalty. 1.0 means no penalty."""
suffix: Optional[str] = None
"""The text to append to the prompt. """
do_sample: bool = True
"""If set to True, this parameter enables decoding strategies such as
multinomial sampling, beam-search multinomial sampling, Top-K
sampling and Top-p sampling.
"""
watermark: bool = True
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
Defaults to True."""
return_full_text: bool = False
"""Whether to prepend the prompt to the generated text. Defaults to False."""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_model_deployment_tgi_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for invoking OCI model deployment TGI endpoint."""
return (
{
"model": self.model, # can be any
"frequency_penalty": self.frequency_penalty,
"max_tokens": self.max_tokens,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.p,
"seed": self.seed,
"stream": self.streaming,
"suffix": self.suffix,
"stop": self.stop,
}
if self.api == "/v1/completions"
else {
"best_of": self.best_of,
"max_new_tokens": self.max_tokens,
"temperature": self.temperature,
"top_k": (
self.k if self.k > 0 else None
), # `top_k` must be strictly positive'
"top_p": self.p,
"do_sample": self.do_sample,
"return_full_text": self.return_full_text,
"watermark": self.watermark,
"stop": self.stop,
}
)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{
"endpoint": self.endpoint,
"api": self.api,
"model_kwargs": _model_kwargs,
},
**self._default_params,
}
def _construct_json_body(self, prompt: str, params: dict) -> dict:
"""Construct request payload."""
if self.api == "/v1/completions":
return super()._construct_json_body(prompt, params)
return {
"inputs": prompt,
"parameters": params,
}
def _process_response(self, response_json: dict) -> List[Generation]:
    """Format the endpoint response for the configured `api`.

    Args:
        response_json: Decoded JSON response from the endpoint.

    Returns:
        The parent class's result for ``/v1/completions``; otherwise a
        single `Generation` holding the response's ``generated_text``.

    Raises:
        ValueError: If ``generated_text`` is missing from the response.
    """
    if self.api == "/v1/completions":
        return super()._process_response(response_json)

    try:
        generated = response_json["generated_text"]
    except KeyError as e:
        raise ValueError(
            f"Error while formatting response payload.response_json={response_json}"
        ) from e
    else:
        return [Generation(text=generated)]
[docs]
class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
"""VLLM deployed on OCI Data Science Model Deployment.

To use, you must provide the model HTTP endpoint from your deployed
model, e.g. https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict.

To authenticate, `oracle-ads` is used to automatically load
credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

Make sure to have the required policies to access the OCI Data
Science Model Deployment endpoint. See:
https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

Example:

    .. code-block:: python

        from langchain_community.llms import OCIModelDeploymentVLLM

        llm = OCIModelDeploymentVLLM(
            endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<md_ocid>/predict",
            model="odsc-llm",
            streaming=False,
            temperature=0.2,
            max_tokens=512,
            n=3,
            best_of=3,
            # other model parameters
        )
"""
n: int = 1
"""Number of output sequences to return for the given prompt."""
k: int = -1
"""Number of most likely tokens to consider at each step.
Sent to the endpoint as `top_k`."""
frequency_penalty: float = 0.0
"""Penalizes repeated tokens according to frequency. Between 0 and 1."""
presence_penalty: float = 0.0
"""Penalizes repeated tokens. Between 0 and 1."""
use_beam_search: bool = False
"""Whether to use beam search instead of sampling."""
ignore_eos: bool = False
"""Whether to ignore the EOS token and continue generating tokens after
the EOS token is generated."""
logprobs: Optional[int] = None
"""Number of log probabilities to return per output token."""
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_model_deployment_vllm_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling vllm."""
return {
"best_of": self.best_of,
"frequency_penalty": self.frequency_penalty,
"ignore_eos": self.ignore_eos,
"logprobs": self.logprobs,
"max_tokens": self.max_tokens,
"model": self.model,
"n": self.n,
"presence_penalty": self.presence_penalty,
"stop": self.stop,
"stream": self.streaming,
"temperature": self.temperature,
"top_k": self.k,
"top_p": self.p,
"use_beam_search": self.use_beam_search,
}