import os
from typing import Any, Dict, List, Mapping, Optional, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr
class Predibase(LLM):
    """Use your Predibase models with LangChain.

    To use, you should have the ``predibase`` python package installed and
    have your Predibase API key.

    The `model` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the
    catalog).

    An optional `adapter_id` parameter is the Predibase ID or HuggingFace ID
    of a fine-tuned LLM adapter whose base model is the `model` parameter;
    the fine-tuned adapter must be compatible with its base model, otherwise
    an error is raised. If the fine-tuned adapter is hosted at Predibase,
    then `adapter_version` in the adapter repository must be specified.

    An optional `predibase_sdk_version` parameter defaults to the latest SDK
    version.
    """

    # Predibase deployment (base model) ID to send prompts to.
    model: str
    # API token; held as a SecretStr so it is masked when the model is printed.
    predibase_api_key: SecretStr
    # Overrides the installed SDK version when choosing between code paths.
    predibase_sdk_version: Optional[str] = None
    # Fine-tuned adapter to apply on top of `model` (Predibase or HuggingFace).
    adapter_id: Optional[str] = None
    # Adapter version inside the Predibase adapter repository.
    adapter_version: Optional[int] = None
    # Instance-level generation options; when non-empty they replace the
    # defaults below as a whole (they are not merged with them).
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    # Generation options used when `model_kwargs` is empty.
    default_options_for_generation: dict = Field(
        {
            "max_new_tokens": 256,
            "temperature": 0.1,
        },
        const=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "predibase"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` using Predibase.

        Args:
            prompt: The text to complete.
            stop: Accepted for interface compatibility; this implementation
                does not forward stop sequences to Predibase.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Per-call generation options. They take precedence over
                ``model_kwargs`` / ``default_options_for_generation``.

        Returns:
            The generated text.

        Raises:
            ImportError: If the required Predibase packages are missing.
            ValueError: If the API key is rejected (deprecated SDK), or the
                deployment or adapter cannot be resolved (current SDK).
        """
        # Build a fresh dict so the instance-level options are never handed
        # to the SDK by reference, and so per-call overrides are honored
        # instead of being silently discarded.
        options: Dict[str, Any] = {
            **(self.model_kwargs or self.default_options_for_generation),
            **kwargs,
        }
        if self._is_deprecated_sdk_version():
            return self._generate_with_deprecated_sdk(prompt, options)
        return self._generate_with_current_sdk(prompt, options)

    def _generate_with_deprecated_sdk(
        self, prompt: str, options: Dict[str, Any]
    ) -> str:
        """Generate text through the legacy ``PredibaseClient`` API."""
        try:
            from predibase import PredibaseClient
            from predibase.pql import get_session
            from predibase.pql.api import (
                ServerResponseError,
                Session,
            )
            from predibase.resource.llm.interface import (
                HuggingFaceLLM,
                LLMDeployment,
            )
            from predibase.resource.llm.response import GeneratedResponse
            from predibase.resource.model import Model

            session: Session = get_session(
                token=self.predibase_api_key.get_secret_value(),
                gateway="https://api.app.predibase.com/v1",
                serving_endpoint="serving.app.predibase.com",
            )
            pc: PredibaseClient = PredibaseClient(session=session)
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

        base_llm_deployment: LLMDeployment = pc.LLM(
            uri=f"pb://deployments/{self.model}"
        )
        result: GeneratedResponse
        if self.adapter_id:
            # Attempt to retrieve the fine-tuned adapter from a Predibase
            # repository. If absent, then load the fine-tuned adapter
            # from a HuggingFace repository.
            adapter_model: Union[Model, HuggingFaceLLM]
            try:
                adapter_model = pc.get_model(
                    name=self.adapter_id,
                    version=self.adapter_version,
                    model_id=None,
                )
            except ServerResponseError:
                # Predibase does not recognize the adapter ID (query HuggingFace).
                adapter_model = pc.LLM(uri=f"hf://{self.adapter_id}")
            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                prompt=prompt,
                options=options,
            )
        else:
            result = base_llm_deployment.generate(
                prompt=prompt,
                options=options,
            )
        return result.response

    def _generate_with_current_sdk(
        self, prompt: str, options: Dict[str, Any]
    ) -> str:
        """Generate text through the current ``Predibase`` / LoRAX client API."""
        from predibase import Predibase

        # NOTE: this overwrites the process-wide PREDIBASE_GATEWAY variable on
        # every call, replacing any value that was set beforehand.
        os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
        predibase: Predibase = Predibase(
            api_token=self.predibase_api_key.get_secret_value()
        )

        import requests
        from lorax.client import Client as LoraxClient
        from lorax.errors import GenerationError
        from lorax.types import Response

        lorax_client: LoraxClient = predibase.deployments.client(
            deployment_ref=self.model
        )

        response: Response
        if not self.adapter_id:
            # No adapter requested: query the base deployment directly.
            try:
                response = lorax_client.generate(
                    prompt=prompt,
                    **options,
                )
            except requests.JSONDecodeError as jde:
                raise ValueError(
                    f'An LLM with the deployment ID "{self.model}" cannot be found '
                    "at Predibase (please refer to "
                    '"https://docs.predibase.com/user-guide/inference/models" '
                    "for the list of "
                    "supported models).\n"
                ) from jde
        elif self.adapter_version:
            # Since the adapter version is provided, query the Predibase repository.
            pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
            try:
                response = lorax_client.generate(
                    prompt=prompt,
                    adapter_id=pb_adapter_id,
                    **options,
                )
            except GenerationError as ge:
                raise ValueError(
                    f'An adapter with the ID "{pb_adapter_id}" cannot be '
                    "found in the Predibase repository of fine-tuned adapters."
                ) from ge
        else:
            # The adapter version is omitted,
            # hence look for the adapter ID in the HuggingFace repository.
            try:
                response = lorax_client.generate(
                    prompt=prompt,
                    adapter_id=self.adapter_id,
                    adapter_source="hub",
                    **options,
                )
            except GenerationError as ge:
                raise ValueError(
                    f'Either an adapter with the ID "{self.adapter_id}" '
                    "cannot be found in a HuggingFace repository, "
                    "or it is incompatible with the "
                    "base model (please make sure that the adapter "
                    "configuration is consistent).\n"
                ) from ge

        return response.generated_text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        Includes the deployment and adapter selection alongside
        ``model_kwargs`` so that instances targeting different models or
        adapters are distinguishable. The API key is intentionally excluded.
        """
        return {
            "model": self.model,
            "adapter_id": self.adapter_id,
            "adapter_version": self.adapter_version,
            "model_kwargs": self.model_kwargs,
        }

    def _is_deprecated_sdk_version(self) -> bool:
        """Tell whether the legacy Predibase SDK code path must be used.

        The version checked is ``predibase_sdk_version`` when set, otherwise
        the installed package's ``__version__``. A version is treated as
        current when it is greater than ``2024.4.8`` or contains ``+dev``.

        Raises:
            ImportError: If ``semantic_version`` or ``predibase`` is missing.
        """
        try:
            import semantic_version
            from predibase.version import __version__ as current_version
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install semantic_version predibase`."
            ) from e

        version_text: str = self.predibase_sdk_version or current_version
        last_deprecated = semantic_version.Version(version_string="2024.4.8")
        in_use = semantic_version.Version(version_string=version_text)
        uses_current_sdk = (in_use > last_deprecated) or ("+dev" in version_text)
        return not uses_current_sdk