ray.air.integrations.mlflow.setup_mlflow#
- ray.air.integrations.mlflow.setup_mlflow(config: Dict | None = None, tracking_uri: str | None = None, registry_uri: str | None = None, experiment_id: str | None = None, experiment_name: str | None = None, tracking_token: str | None = None, artifact_location: str | None = None, run_name: str | None = None, create_experiment_if_not_exists: bool = False, tags: Dict | None = None, rank_zero_only: bool = True) ModuleType | _NoopModule [源代码]#
设置一个 MLflow 会话。
此函数可用于在(分布式)训练或调优运行中初始化一个 MLflow 会话。该会话将在可训练对象上创建。
默认情况下,MLflow 实验 ID 是 Ray 试验 ID,而 MLflow 实验名称是 Ray 试验名称。这些设置可以通过传递相应的关键字参数来覆盖。
config
字典会自动记录为运行参数(不包括 mlflow 设置)。在使用 Ray Train 进行分布式训练时,只有零级工作进程会初始化 mlflow。所有其他工作进程将返回一个 noop 客户端,以避免在分布式运行中重复记录。可以通过传递
rank_zero_only=False
来禁用此功能,这样每个训练工作进程都会初始化 mlflow。此函数将返回
mlflow
模块或非零级工作者的noop模块如果 rank_zero_only=True
。通过使用mlflow = setup_mlflow(config)
,您可以确保只有零级工作者调用mlflow API。- 参数:
config – 要记录到 mlflow 中的配置字典作为参数。
tracking_uri – MLflow 跟踪的跟踪 URI。如果在多节点设置中使用 Tune,请确保使用远程服务器进行跟踪。
registry_uri – MLflow 模型注册表的注册 URI。
experiment_id – 一个已创建的 MLflow 实验的 ID。所有来自
tune.Tuner()
中所有试验的日志都将报告到此实验。如果未提供此参数或具有此 ID 的实验不存在,则必须提供一个experiment_name
。此参数优先于experiment_name
。experiment_name – 一个已经存在的 MLflow 实验的名称。所有来自
tune.Tuner()
中所有试验的日志都将报告给这个实验。如果没有提供这个名称,你必须提供一个有效的experiment_id
。tracking_token – 用于在登录远程跟踪服务器时进行HTTP身份验证的令牌。例如,当你想登录到Databricks服务器时,这非常有用。此值将用于在所有远程训练过程中设置MLFLOW_TRACKING_TOKEN环境变量。
artifact_location – 存储运行工件的位置。如果未提供,MLFlow 会选择一个适当的默认值。如果实验已存在,则忽略此项。
run_name – 将被创建的新 MLflow 运行的名称。如果未设置,将默认为
experiment_name
。create_experiment_if_not_exists – 如果提供的名称不存在,是否创建一个实验。默认为 False。
tags – 新运行的标签设置。
rank_zero_only – 如果为 True,则仅在分布式训练中为 rank 0 的 worker 返回一个初始化的会话。如果为 False,则为所有 worker 初始化一个会话。默认为 True。
示例
默认情况下,您只需调用
setup_mlflow
并继续像平常一样使用 MLflow:from ray.air.integrations.mlflow import setup_mlflow def training_loop(config): mlflow = setup_mlflow(config) # ... mlflow.log_metric(key="loss", val=0.123, step=0)
在分布式数据并行训练中,你可以利用
setup_mlflow
的返回值。这将确保它仅在分布式训练运行的第一个工作节点上调用。from ray.air.integrations.mlflow import setup_mlflow def training_loop(config): mlflow = setup_mlflow(config) # ... mlflow.log_metric(key="loss", val=0.123, step=0)
如果你使用的是像 Pytorch Lightning、XGBoost 等训练框架,你也可以使用 MlFlow 的自动记录功能。更多信息可以在这里找到 (https://mlflow.org/docs/latest/tracking.html#automatic-logging)。
from ray.air.integrations.mlflow import setup_mlflow def train_fn(config): mlflow = setup_mlflow(config) mlflow.autolog() xgboost_results = xgb.train(config, ...)
PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。