Source code for langchain_community.llms.databricks

import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    PrivateAttr,
    root_validator,
    validator,
)

__all__ = ["Databricks"]


class _DatabricksClientBase(BaseModel, ABC):
    """一个基本的JSON API客户端,用于与Databricks通信。"""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        return False


def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]


def _transform_llama2_chat(response: Dict[str, Any]) -> str:
    return response["candidates"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]


class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """一个与Databricks服务端点通信的API客户端。"""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values
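
    # With hypothetical values host="example.cloud.databricks.com" and
    # endpoint_name="my-endpoint", the validator above resolves api_url to:
    #   https://example.cloud.databricks.com/serving-endpoints/my-endpoint/invocations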

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred
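
# Sketch of the request/response flow in post() above for a non-external
# (custom MLflow model) endpoint, assuming the model signature described in
# the Databricks class docstring below; the values are illustrative:
#
#   request  = {"prompt": "Hello", "stop": []}
#   wrapped  = {"dataframe_records": [request]}   # what post() sends
#   response = {"predictions": ["Hi there!"]}     # what serving returns
#   post(request)                                 # -> "Hi there!"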


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """一个与Databricks集群驱动程序代理应用程序通信的API客户端。"""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        resp = self._post(self.api_url, request)
        return transform_output_fn(resp) if transform_output_fn else resp
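
# A hedged construction example (all values hypothetical); the root validator
# above derives api_url from the host, cluster ID, and driver port:
#
#   client = _DatabricksClusterDriverProxyClient(
#       host="example.cloud.databricks.com",
#       api_token="dapi...",
#       cluster_id="0123-456789-abcdef",
#       cluster_driver_port="7777",
#   )
#   client.api_url
#   # -> "https://example.cloud.databricks.com/driver-proxy-api/o/0/0123-456789-abcdef/7777"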


def get_repl_context() -> Any:
    """Get the notebook REPL context if running inside a Databricks notebook.

    Raises an ImportError if not running inside a Databricks runtime.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        )


def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    Raises an error if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # Note: str.lstrip removes a *character set*, not a prefix, so strip the
    # URL scheme with a regex instead.
    host = re.sub(r"^https?://", "", host).rstrip("/")
    return host


def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    Raises an error if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token

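
# Resolution sketch for the two helpers above: the environment variable wins,
# then the notebook REPL context; outside both, a ValueError is raised. The
# values here are hypothetical:
#
#   os.environ["DATABRICKS_HOST"] = "example.cloud.databricks.com"
#   os.environ["DATABRICKS_TOKEN"] = "dapi..."
#   get_default_host()       # -> "example.cloud.databricks.com"
#   get_default_api_token()  # -> "dapi..."
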
def _is_hex_string(data: str) -> bool:
    """Check if the data is a valid hexadecimal string using a regular expression."""
    if not isinstance(data, str):
        return False
    pattern = r"^[0-9a-fA-F]+$"
    return bool(re.match(pattern, data))


def _load_pickled_fn_from_hex_string(
    data: str, allow_dangerous_deserialization: Optional[bool]
) -> Callable:
    """Load a pickled function from a hexadecimal string."""
    if not allow_dangerous_deserialization:
        raise ValueError(
            "This code relies on the pickle module. "
            "You will need to set allow_dangerous_deserialization=True "
            "if you want to opt-in to allow deserialization of data using pickle. "
            "Data can be compromised by a malicious actor if "
            "not handled properly to include "
            "a malicious payload that when deserialized with "
            "pickle can execute arbitrary code on your machine."
        )

    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.loads(bytes.fromhex(data))
    except Exception as e:
        raise ValueError(
            f"Failed to load the pickled function from a hexadecimal string. "
            f"Error: {e}"
        )


def _pickle_fn_to_hex_string(fn: Callable) -> str:
    """Pickle a function and return it as a hexadecimal string."""
    try:
        import cloudpickle
    except Exception as e:
        raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

    try:
        return cloudpickle.dumps(fn).hex()
    except Exception as e:
        raise ValueError(f"Failed to pickle the function: {e}")

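
# Round-trip sketch for the two pickle helpers above; requires cloudpickle and
# an explicit opt-in, since unpickling can execute arbitrary code:
#
#   hex_str = _pickle_fn_to_hex_string(lambda s: s.upper())
#   fn = _load_pickled_fn_from_hex_string(
#       hex_str, allow_dangerous_deserialization=True
#   )
#   fn("hello")  # -> "HELLO"
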
class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected
      model signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response
      from the endpoint is automatically transformed to the expected format
      unless ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local
      HTTP server on the driver node to serve the model at ``/`` using HTTP
      POST with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server
      listen to the driver IP address or simply ``0.0.0.0`` instead of
      localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the
      cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set
      ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra
    params, you can use ``transform_input_fn`` and ``transform_output_fn`` to
    apply the necessary transformations before and after the query.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside a
      Databricks notebook in "single user" or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code
    runs inside a Databricks notebook attached to an interactive cluster in
    "single user" or "no isolation shared" mode, the current cluster ID is
    used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver
    node. The server should listen on the driver IP address or simply
    ``0.0.0.0`` to connect.
    We recommend the server using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """Deprecated. Please use ``extra_params`` instead. Extra parameters to
    pass to the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a
    JSON-compatible request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the
    generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    allow_dangerous_deserialization: bool = False
    """Whether to allow dangerous deserialization of the data, which involves
    loading data using pickle.

    If the data has been modified by a malicious actor, it can deliver a
    malicious payload that results in execution of arbitrary code on the
    target machine.
    """

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
" "And the cluster_id cannot be automatically determined. Received" f" error: {e}" ) @validator("cluster_driver_port", always=True) def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]: if v and values["endpoint_name"]: raise ValueError("Cannot set both endpoint_name and cluster_driver_port.") elif values["endpoint_name"]: return None elif v is None: raise ValueError( "Must set cluster_driver_port to connect to a cluster driver." ) elif int(v) <= 0: raise ValueError(f"Invalid cluster_driver_port: {v}") else: return v @validator("model_kwargs", always=True) def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: if v: assert "prompt" not in v, "model_kwargs must not contain key 'prompt'" assert "stop" not in v, "model_kwargs must not contain key 'stop'" return v def __init__(self, **data: Any): if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]): data["transform_input_fn"] = _load_pickled_fn_from_hex_string( data=data["transform_input_fn"], allow_dangerous_deserialization=data.get( "allow_dangerous_deserialization" ), ) if "transform_output_fn" in data and _is_hex_string( data["transform_output_fn"] ): data["transform_output_fn"] = _load_pickled_fn_from_hex_string( data=data["transform_output_fn"], allow_dangerous_deserialization=data.get( "allow_dangerous_deserialization" ), ) super().__init__(**data) if self.model_kwargs is not None and self.extra_params is not None: raise ValueError("Cannot set both extra_params and extra_params.") elif self.model_kwargs is not None: warnings.warn( "model_kwargs is deprecated. Please use extra_params instead.", DeprecationWarning, ) if self.endpoint_name: self._client = _DatabricksServingEndpointClient( host=self.host, api_token=self.api_token, endpoint_name=self.endpoint_name, databricks_uri=self.databricks_uri, task=self.task, ) elif self.cluster_id and self.cluster_driver_port: self._client = _DatabricksClusterDriverProxyClient( # type: ignore[call-arg] host=self.host, api_token=self.api_token, cluster_id=self.cluster_id, cluster_driver_port=self.cluster_driver_port, ) else: raise ValueError( "Must specify either endpoint_name or cluster_id/cluster_driver_port." 
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""
        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)
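
# Minimal usage sketch; the endpoint name and cluster values below are
# illustrative, not defaults shipped with the library:
#
#   from langchain_community.llms import Databricks
#
#   # Against a model serving endpoint (requires "Can Query" permission):
#   llm = Databricks(endpoint_name="databricks-llama-2-70b-chat")
#   llm.invoke("What is MLflow?")
#
#   # Or against a cluster driver proxy app (requires "Can Attach To"):
#   llm = Databricks(cluster_id="0123-456789-abcdef", cluster_driver_port="7777")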