Skip to content

Predibase

PredibaseLLM #

Bases: CustomLLM

Predibase LLM。

要使用,您应该已安装 predibase python包,并拥有您的Predibase API密钥。

model_name 参数是Predibase的“无服务器” base_model ID (参见 https://docs.predibase.com/user-guide/inference/models 获取目录)。

可选的 adapter_id 参数是经过微调的LLM适配器的Predibase ID或HuggingFace ID, 其基础模型由 model_name 参数指定;经过微调的适配器必须与其基础模型兼容,否则会引发错误。 如果经过微调的适配器托管在Predibase上,则必须指定 adapter_version。

示例

pip install llama-index-llms-predibase

import os

os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from llama_index.llms.predibase import PredibaseLLM

llm = PredibaseLLM(
    model_name="mistral-7b",
    predibase_sdk_version=None,  # 可选参数(如果省略,默认为最新的Predibase SDK版本)
    adapter_id="my-adapter-id",  # 可选参数
    adapter_version=3,  # 可选参数(仅适用于Predibase)
    temperature=0.3,
    max_new_tokens=512,
)
response = llm.complete("Hello World!")
print(str(response))
Source code in llama_index/llms/predibase/base.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class PredibaseLLM(CustomLLM):
    """Predibase LLM.

    To use it, you should have the ``predibase`` Python package installed and
    have your Predibase API key.

    The `model_name` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).

    The optional `adapter_id` parameter is the Predibase ID or HuggingFace ID of
    a fine-tuned LLM adapter whose base model is given by the `model_name`
    parameter; the fine-tuned adapter must be compatible with its base model,
    otherwise an error is raised. If the fine-tuned adapter is hosted at
    Predibase, then `adapter_version` must be specified.

    Per-call keyword arguments passed to `complete` are forwarded as generation
    options; `max_new_tokens` and `temperature` fall back to the instance
    values only when the caller does not supply them.

    Examples:
        `pip install llama-index-llms-predibase`

        ```python
        import os

        os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

        from llama_index.llms.predibase import PredibaseLLM

        llm = PredibaseLLM(
            model_name="mistral-7b",
            predibase_sdk_version=None,  # optional (defaults to the latest Predibase SDK version if omitted)
            adapter_id="my-adapter-id",  # optional
            adapter_version=3,  # optional (applies to Predibase only)
            temperature=0.3,
            max_new_tokens=512,
        )
        response = llm.complete("Hello World!")
        print(str(response))
        ```
    """

    model_name: str = Field(description="The Predibase base model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    predibase_sdk_version: Optional[str] = Field(
        default=None,
        description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).",
    )
    adapter_id: Optional[str] = Field(
        default=None,
        description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
    )
    # NOTE(review): ``__init__`` accepts an ``int`` here while the field is
    # declared as a string; the stored type depends on the model layer's
    # coercion rules -- confirm before relying on either type downstream.
    adapter_version: Optional[str] = Field(
        default=None,
        description="The optional version number of fine-tuned adapter use (applies to Predibase only).",
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    # ``ge``/``le`` are the constraint keywords; the previous ``gte``/``lte``
    # spellings are not recognised constraints, so the bound was never checked.
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        predibase_sdk_version: Optional[str] = None,
        adapter_id: Optional[str] = None,
        adapter_version: Optional[int] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Validate the credentials, populate the model fields, and build the SDK client.

        Args:
            model_name: Predibase base model (deployment) ID.
            predibase_api_key: API token; when falsy, the ``PREDIBASE_API_TOKEN``
                environment variable is used instead.
            predibase_sdk_version: Optional SDK version string used to select
                between the deprecated and the current client code paths.
            adapter_id: Optional Predibase or HuggingFace adapter ID.
            adapter_version: Optional adapter version (Predibase-hosted adapters).
            max_new_tokens: Default number of tokens to generate.
            temperature: Default sampling temperature.
            context_window: Number of context tokens reported in the metadata.
            callback_manager, system_prompt, messages_to_prompt,
            completion_to_prompt, pydantic_program_mode, output_parser:
                Forwarded unchanged to the base class constructor.

        Raises:
            ValueError: If no API key is supplied and ``PREDIBASE_API_TOKEN`` is
                unset or empty.
        """
        predibase_api_key = predibase_api_key or os.environ.get("PREDIBASE_API_TOKEN")
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O`` and an empty token can never authenticate.
        if not predibase_api_key:
            raise ValueError(
                "A Predibase API key is required: pass `predibase_api_key` or set "
                "the PREDIBASE_API_TOKEN environment variable."
            )

        super().__init__(
            model_name=model_name,
            predibase_api_key=predibase_api_key,
            predibase_sdk_version=predibase_sdk_version,
            adapter_id=adapter_id,
            adapter_version=adapter_version,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

        self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client()

    def initialize_client(
        self,
    ) -> Union["PredibaseClient", "Predibase"]:
        """Create the Predibase client matching the selected SDK generation.

        Returns:
            A ``PredibaseClient`` for deprecated SDK versions, otherwise a
            ``Predibase`` client.

        Raises:
            ValueError: If client construction raises ``ValueError`` (reported as
                an incorrect API key), or if the SDK version string cannot be
                parsed.
            ImportError: If the required packages are not installed.
        """
        # Resolve the SDK generation *outside* the try-block below so that a
        # malformed version string is not misreported as a bad API key.
        use_deprecated_sdk: bool = self._is_deprecated_sdk_version()

        try:
            if use_deprecated_sdk:
                from predibase import PredibaseClient
                from predibase.pql import get_session
                from predibase.pql.api import Session

                session: Session = get_session(
                    token=self.predibase_api_key,
                    gateway="https://api.app.predibase.com/v1",
                    serving_endpoint="serving.app.predibase.com",
                )
                return PredibaseClient(session=session)

            from predibase import Predibase

            # NOTE: this mutates the process-wide environment.
            os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
            return Predibase(api_token=self.predibase_api_key)
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        """Generate a completion for ``prompt``.

        Args:
            prompt: The prompt text sent to the model.
            formatted: Accepted for interface compatibility; not used here.
            **kwargs: Extra generation options. ``max_new_tokens`` and
                ``temperature`` supplied here take precedence over the
                instance-level values.

        Returns:
            A ``CompletionResponse`` carrying the generated text.

        Raises:
            ValueError: In the current-SDK path, if the adapter cannot be used
                (``GenerationError``) or, without an adapter, if the response
                cannot be decoded (``requests.JSONDecodeError``).
        """
        options: Dict[str, Union[str, float]] = copy.deepcopy(kwargs)
        # Instance settings are only defaults: do not clobber per-call values.
        options.setdefault("max_new_tokens", self.max_new_tokens)
        options.setdefault("temperature", self.temperature)

        response_text: str

        if self._is_deprecated_sdk_version():
            from predibase.pql.api import ServerResponseError
            from predibase.resource.llm.interface import (
                HuggingFaceLLM,
                LLMDeployment,
            )
            from predibase.resource.llm.response import GeneratedResponse
            from predibase.resource.model import Model

            base_llm_deployment: LLMDeployment = self._client.LLM(
                uri=f"pb://deployments/{self.model_name}"
            )

            result: GeneratedResponse
            if self.adapter_id:
                # Attempt to retrieve the fine-tuned adapter from a Predibase
                # repository. If absent, then load the fine-tuned adapter from
                # a HuggingFace repository.
                adapter_model: Union[Model, HuggingFaceLLM]
                try:
                    adapter_model = self._client.get_model(
                        name=self.adapter_id,
                        version=self.adapter_version,
                        model_id=None,
                    )
                except ServerResponseError:
                    # Predibase does not recognize the adapter ID (query HuggingFace).
                    adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
                result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                    prompt=prompt,
                    options=options,
                )
            else:
                result = base_llm_deployment.generate(
                    prompt=prompt,
                    options=options,
                )
            response_text = result.response
        else:
            import requests
            from lorax.client import Client as LoraxClient
            from lorax.errors import GenerationError
            from lorax.types import Response

            lorax_client: LoraxClient = self._client.deployments.client(
                deployment_ref=self.model_name
            )

            response: Response
            if self.adapter_id:
                # The adapter source is chosen by whether a version was given:
                # with a version the adapter is looked up at Predibase, without
                # one it is looked up on the HuggingFace hub (no fallback).
                if self.adapter_version:
                    # Since the adapter version is provided, query the Predibase repository.
                    pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
                    try:
                        response = lorax_client.generate(
                            prompt=prompt,
                            adapter_id=pb_adapter_id,
                            **options,
                        )
                    except GenerationError as ge:
                        raise ValueError(
                            f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.'
                        ) from ge
                else:
                    # The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository.
                    try:
                        response = lorax_client.generate(
                            prompt=prompt,
                            adapter_id=self.adapter_id,
                            adapter_source="hub",
                            **options,
                        )
                    except GenerationError as ge:
                        raise ValueError(
                            f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \
or it is incompatible with the base model (please make sure that the adapter configuration is consistent).
"""
                        ) from ge
            else:
                try:
                    response = lorax_client.generate(
                        prompt=prompt,
                        **options,
                    )
                except requests.JSONDecodeError as jde:
                    raise ValueError(
                        f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \
(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models).
"""
                    ) from jde
            response_text = response.generated_text

        return CompletionResponse(text=response_text)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        """Streaming completion is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _is_deprecated_sdk_version(self) -> bool:
        """Tell whether the deprecated Predibase SDK code path must be used.

        The effective version is ``predibase_sdk_version`` when set, otherwise
        the installed ``predibase`` package version. It is deprecated when it is
        not greater than ``2024.4.8`` and does not contain ``"+dev"``.

        Returns:
            ``True`` for a deprecated SDK version, ``False`` otherwise.

        Raises:
            ImportError: If ``semantic_version`` or ``predibase`` is missing.
        """
        # Only the imports are guarded; parsing errors propagate unchanged.
        try:
            import semantic_version
            from semantic_version.base import Version

            from predibase.version import __version__ as current_version
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install semantic_version predibase`."
            ) from e

        sdk_semver_deprecated: Version = semantic_version.Version(
            version_string="2024.4.8"
        )
        actual_current_version: str = self.predibase_sdk_version or current_version
        sdk_semver_current: Version = semantic_version.Version(
            version_string=actual_current_version
        )
        return not (
            (sdk_semver_current > sdk_semver_deprecated)
            or ("+dev" in actual_current_version)
        )

metadata property #

metadata: LLMMetadata

获取LLM元数据。