备注
前往结尾以下载完整示例代码。
在CPU上使用Dask训练生存模型的示例
import os
import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix
def main(client: Client) -> da.Array:
    """Train an AFT survival model through Dask and predict on the training set.

    The example data is the Veterans' Administration Lung Cancer Trial, from
    "The Statistical Analysis of Failure Time Data" by Kalbfleisch J. and
    Prentice R (1980). It is read from ``../data/veterans_lung_cancer.csv``
    relative to this file.

    Parameters
    ----------
    client :
        Dask distributed client used to build the ``DaskDMatrix`` and to run
        training and prediction.

    Returns
    -------
    The result of ``dxgb.predict`` on the training matrix.
    """
    # Booster configuration for the accelerated failure time objective.
    params = {
        "verbosity": 1,
        "objective": "survival:aft",
        "eval_metric": "aft-nloglik",
        "learning_rate": 0.05,
        "aft_loss_distribution_scale": 1.20,
        "aft_loss_distribution": "normal",
        "max_depth": 6,
        "lambda": 0.01,
        "alpha": 0.02,
    }

    # Read the example survival data from CSV into a Dask data frame.
    csv_path = os.path.join(
        os.path.dirname(__file__), os.pardir, "data", "veterans_lung_cancer.csv"
    )
    frame = dd.read_csv(csv_path)

    # AFT survival labels are intervals: the lower and upper bounds are taken
    # out of the frame and handed to DaskDMatrix separately, while every
    # remaining column is used as a feature.
    lower_name = "Survival_label_lower_bound"
    upper_name = "Survival_label_upper_bound"
    lower_bound = frame[lower_name]
    upper_bound = frame[upper_name]
    features = frame.drop([lower_name, upper_name], axis=1)

    # DaskDMatrix behaves like a regular DMatrix; it is a proxy for the local
    # DMatrix objects scattered around the workers.
    dtrain = DaskDMatrix(
        client,
        features,
        label_lower_bound=lower_bound,
        label_upper_bound=upper_bound,
    )

    # The train function from xgboost.dask (not plain xgboost) returns a
    # dictionary holding the resulting booster and the evaluation history
    # collected from the evaluation metrics.
    result = dxgb.train(
        client,
        params,
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, "train")],
    )
    booster = result["booster"]

    # The whole training output could be passed to `predict` as well.
    prediction = dxgb.predict(client, booster, dtrain)
    print("Evaluation history: ", result["history"])

    # To save the model to disk, call e.g.
    # booster.save_model('survival_model.json')
    return prediction
if __name__ == "__main__":
    # A local cluster is used here; swap in another cluster type for scaling.
    # Both context managers are entered left to right and released in reverse,
    # so the client is closed before the cluster is shut down.
    with LocalCluster(n_workers=7, threads_per_worker=4) as cluster, Client(
        cluster
    ) as client:
        main(client)