SparkXGBForecast

Spark XGBoost 预测器

`xgboost.spark.SparkXGBRegressor` 的封装器，增加了 `extract_local_model` 方法，用于获取训练后模型的本地版本，以便将其广播到各工作节点。

import xgboost as xgb

try:
    # NOTE: the trailing comment must be exactly "type: ignore" (in English)
    # for type checkers to honor it; a translated comment is silently ignored.
    from xgboost.spark import SparkXGBRegressor  # type: ignore
except ModuleNotFoundError:
    # When xgboost.spark cannot be imported, only tolerate it under the test
    # harness (IN_TEST=1): fall back to ``object`` as the base class so this
    # module stays importable. Outside tests the import error is re-raised.
    import os

    if os.getenv('IN_TEST', '0') == '1':
        SparkXGBRegressor = object
    else:
        raise
class SparkXGBForecast(SparkXGBRegressor):
    """Spark XGBoost regressor that can produce a local (non-Spark) model.

    Subclass of ``xgboost.spark.SparkXGBRegressor`` with two additions: a
    hook that sets the label column before fitting, and a method that copies
    a trained Spark model's booster into a plain ``xgboost.XGBRegressor``.
    """

    def _pre_fit(self, target_col):
        """Set the estimator's ``label_col`` param to ``target_col``.

        Returns:
            The estimator itself, so the call can be chained.
        """
        self.setParams(label_col=target_col)
        return self

    def extract_local_model(self, trained_model):
        """Return an ``xgboost.XGBRegressor`` holding ``trained_model``'s booster.

        Args:
            trained_model: a fitted model exposing ``get_booster()`` —
                presumably the result of fitting this estimator.

        Returns:
            A new ``xgb.XGBRegressor`` loaded from the booster's raw
            UBJSON ('ubj') serialization.
        """
        booster = trained_model.get_booster()
        serialized = booster.save_raw('ubj')
        regressor = xgb.XGBRegressor()
        regressor.load_model(serialized)
        return regressor

Give us a ⭐ on GitHub