import os
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
PrivateAttr,
root_validator,
validator,
)
# Explicit public API: only the Databricks LLM wrapper is exported.
__all__ = ["Databricks"]
class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def request(self, method: str, url: str, request: Any) -> Any:
        """Send an authenticated JSON request and return the decoded response."""
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if response.ok:
            return response.json()
        raise ValueError(f"HTTP {response.status_code} error: {response.text}")

    def _get(self, url: str) -> Any:
        """Issue a GET request with no body."""
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        """Issue a POST request carrying a JSON body."""
        return self.request("POST", url, request)

    @abstractmethod
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...

    @property
    def llm(self) -> bool:
        """Whether the target serves a recognized LLM task (overridden by subclasses)."""
        return False
def _transform_completions(response: Dict[str, Any]) -> str:
return response["choices"][0]["text"]
def _transform_llama2_chat(response: Dict[str, Any]) -> str:
return response["candidates"][0]["text"]
def _transform_chat(response: Dict[str, Any]) -> str:
return response["choices"][0]["message"]["content"]
class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks model serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    # MLflow deployment client; created in __init__ (requires mlflow installed).
    client: Any = None
    # True for external-model / foundation-model endpoints, which accept the
    # request payload as-is instead of the dataframe_records wrapping.
    external_or_foundation: bool = False
    # Endpoint task (e.g. "llm/v1/chat"); inferred from the endpoint if unset.
    task: Optional[str] = None

    def __init__(self, **data: Any):
        """Create the mlflow deployment client and probe the endpoint's type/task.

        Raises:
            ImportError: if mlflow is not installed.
        """
        super().__init__(**data)
        try:
            from mlflow.deployments import get_deploy_client
            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e
        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        if self.task is None:
            # Fall back to the task advertised by the endpoint itself.
            self.task = endpoint.get("task")

    @property
    def llm(self) -> bool:
        """Whether the endpoint serves one of the known LLM task schemas."""
        return self.task in ("llm/v1/chat", "llm/v1/completions", "llama2/chat")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Derive ``api_url`` from host and endpoint name when not supplied."""
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        """Send the request to the endpoint and post-process the response.

        A caller-supplied ``transform_output_fn`` takes precedence over the
        built-in per-task transformations.
        """
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)
            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)
            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            if self.task == "llama2/chat":
                return _transform_llama2_chat(pred)
            return transform_output_fn(pred) if transform_output_fn else pred
class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build the driver-proxy URL from host/cluster/port when ``api_url`` is absent."""
        if "api_url" in values:
            return values
        host = values["host"]
        cluster_id = values["cluster_id"]
        port = values["cluster_driver_port"]
        values["api_url"] = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
        return values

    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        """POST the request to the proxy app, applying the optional output transform."""
        raw = self._post(self.api_url, request)
        if transform_output_fn:
            return transform_output_fn(raw)
        return raw
def get_repl_context() -> Any:
    """Get the notebook REPL context when running inside a Databricks notebook.

    Raises:
        ImportError: if called outside a Databricks notebook, where the
            ``dbruntime`` module is not importable.
    """
    try:
        from dbruntime.databricks_repl_context import get_context
        return get_context()
    except ImportError as e:
        # Chain the original import error for easier debugging.
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        ) from e
def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    Reads the ``DATABRICKS_HOST`` environment variable first and falls back
    to the notebook REPL context. The returned hostname has any URL scheme
    and trailing slashes removed.

    Raises:
        ValueError: if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # NOTE: str.lstrip strips a *character set*, not a prefix — the previous
    # lstrip("https://") also ate leading 'h'/'t'/'p'/'s' characters from the
    # hostname itself. Strip the scheme as an exact prefix instead.
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme):]
            break
    host = host.rstrip("/")
    return host
def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    Reads the ``DATABRICKS_TOKEN`` environment variable first and falls back
    to the notebook REPL context.

    Raises:
        ValueError: if the token cannot be automatically determined.
    """
    if api_token := os.getenv("DATABRICKS_TOKEN"):
        return api_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token
def _is_hex_string(data: str) -> bool:
"""使用正则表达式检查数据是否为有效的十六进制字符串。"""
if not isinstance(data, str):
return False
pattern = r"^[0-9a-fA-F]+$"
return bool(re.match(pattern, data))
def _load_pickled_fn_from_hex_string(
data: str, allow_dangerous_deserialization: Optional[bool]
) -> Callable:
"""从十六进制字符串中加载一个pickle函数。"""
if not allow_dangerous_deserialization:
raise ValueError(
"This code relies on the pickle module. "
"You will need to set allow_dangerous_deserialization=True "
"if you want to opt-in to allow deserialization of data using pickle."
"Data can be compromised by a malicious actor if "
"not handled properly to include "
"a malicious payload that when deserialized with "
"pickle can execute arbitrary code on your machine."
)
try:
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
try:
return cloudpickle.loads(bytes.fromhex(data))
except Exception as e:
raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
)
def _pickle_fn_to_hex_string(fn: Callable) -> str:
"""将一个函数封装并返回十六进制字符串。"""
try:
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")
try:
return cloudpickle.dumps(fn).hex()
except Exception as e:
raise ValueError(f"Failed to pickle the function: {e}")
class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for an LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the
      endpoint. Set ``endpoint_name`` accordingly and do not set
      ``cluster_id`` and ``cluster_driver_port``.

      If the underlying model is a model registered by MLflow, the expected
      model signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the
      response from the endpoint is automatically transformed to the
      expected format unless ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a
      local HTTP server on the driver node to serve the model at ``/`` using
      HTTP POST method with JSON input/output. Please use a port number
      between ``[3000, 8000]`` and let the server listen to the driver IP
      address or simply ``0.0.0.0`` instead of localhost only. To wrap it as
      an LLM you must have "Can Attach To" permission to the cluster. Set
      ``cluster_id`` and ``cluster_driver_port`` and do not set
      ``endpoint_name``.

      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra
    params, you can use ``transform_input_fn`` and ``transform_output_fn``
    to apply the necessary transformations before and after the query.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in
      "single user" or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a
      Databricks notebook attached to an interactive cluster in
      "single user" or "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving
    endpoint. You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is provided and the code
    runs inside a Databricks notebook attached to an interactive cluster in
    "single user" or "no isolation shared" mode, the current cluster ID is
    used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster
    driver node. The server should listen on the driver IP address or
    simply ``0.0.0.0`` to connect.
    We recommend the server using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """Deprecated. Please use ``extra_params`` instead. Extra parameters to
    pass to the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a
    JSON-compatible request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the
    generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    temperature: float = 0.0
    """The sampling temperature."""

    n: int = 1
    """The number of completion choices to generate."""

    stop: Optional[List[str]] = None
    """The stop sequence."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""

    extra_params: Dict[str, Any] = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""

    task: Optional[str] = None
    """The task of the endpoint. Only used when using a serving endpoint.
    If not provided, the task is automatically inferred from the endpoint.
    """

    allow_dangerous_deserialization: bool = False
    """Whether to allow dangerous deserialization of the data which
    involves loading data using pickle.

    If the data is modified by a malicious actor, it can deliver a
    malicious payload that results in execution of arbitrary code on the
    target machine.
    """

    # The API client, selected in __init__ based on which connection
    # parameters were provided.
    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @property
    def _llm_params(self) -> Dict[str, Any]:
        """Sampling parameters sent to endpoints that serve an LLM task."""
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "n": self.n,
        }
        if self.stop:
            params["stop"] = self.stop
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Validate cluster_id and, if needed, infer it from the notebook context."""
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Require a positive driver port when connecting to a cluster driver."""
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Reject model_kwargs keys that collide with the request schema."""
        if v:
            assert "prompt" not in v, "model_kwargs must not contain key 'prompt'"
            assert "stop" not in v, "model_kwargs must not contain key 'stop'"
        return v

    def __init__(self, **data: Any):
        """Deserialize any pickled transform functions and create the API client.

        Raises:
            ValueError: if both ``model_kwargs`` and ``extra_params`` are set,
                or if neither ``endpoint_name`` nor
                ``cluster_id``/``cluster_driver_port`` is specified.
        """
        if "transform_input_fn" in data and _is_hex_string(data["transform_input_fn"]):
            data["transform_input_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_input_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )
        if "transform_output_fn" in data and _is_hex_string(
            data["transform_output_fn"]
        ):
            data["transform_output_fn"] = _load_pickled_fn_from_hex_string(
                data=data["transform_output_fn"],
                allow_dangerous_deserialization=data.get(
                    "allow_dangerous_deserialization"
                ),
            )

        super().__init__(**data)
        # extra_params defaults to {} (never None), so the previous
        # `is not None` check made any use of model_kwargs raise; checking
        # truthiness only rejects genuinely conflicting configurations.
        # The message previously read "extra_params and extra_params".
        if self.model_kwargs is not None and self.extra_params:
            raise ValueError("Cannot set both model_kwargs and extra_params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use extra_params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
                task=self.task,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(  # type: ignore[call-arg]
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params, with transform functions pickled to hex strings."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
            "task": self.task,
            "transform_input_fn": None
            if self.transform_input_fn is None
            else _pickle_fn_to_hex_string(self.transform_input_fn),
            "transform_output_fn": None
            if self.transform_output_fn is None
            else _pickle_fn_to_hex_string(self.transform_output_fn),
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Query the LLM endpoint with the given prompt and stop sequence."""
        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt}
        if self._client.llm:
            request.update(self._llm_params)
        request.update(self.model_kwargs or self.extra_params)
        request.update(kwargs)
        if stop:
            request["stop"] = stop

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        return self._client.post(request, transform_output_fn=self.transform_output_fn)