import xgboost as xgb
try:
    from xgboost.spark import SparkXGBRegressor  # type: ignore
except ModuleNotFoundError:
    import os

    # When xgboost.spark (or one of its dependencies) cannot be imported and
    # IN_TEST=1, fall back to a plain `object` base so the class that
    # subclasses SparkXGBRegressor can still be defined. Outside of tests,
    # re-raise the original ModuleNotFoundError unchanged.
    if os.getenv('IN_TEST', '0') == '1':
        SparkXGBRegressor = object
    else:
        raise
SparkXGBForecast
Spark XGBoost 预测器
xgboost.spark.SparkXGBRegressor 的封装器,新增 extract_local_model 方法,用于将训练后的模型转换为本地(非 Spark)版本,以便随后将其广播到各工作节点。
class SparkXGBForecast(SparkXGBRegressor):
    """Spark XGBoost forecaster.

    Thin wrapper around ``xgboost.spark.SparkXGBRegressor`` that adds
    ``extract_local_model``, which converts a trained Spark model into a
    local ``xgboost.XGBRegressor`` (e.g. so the caller can broadcast it to
    worker nodes).
    """

    def _pre_fit(self, target_col):
        """Point the estimator's label column at ``target_col``.

        Returns ``self`` so the call can be chained.
        """
        label_params = {'label_col': target_col}
        self.setParams(**label_params)
        return self

    def extract_local_model(self, trained_model):
        """Build a local ``xgb.XGBRegressor`` from a trained Spark model.

        The trained model's booster is serialized to XGBoost's UBJSON
        (``'ubj'``) raw format and loaded into a freshly constructed
        regressor.

        :param trained_model: presumably the fitted model produced by this
            estimator; it must expose ``get_booster()`` — confirm at caller.
        :return: the local ``xgb.XGBRegressor`` holding the same booster.
        """
        booster = trained_model.get_booster()
        # save_raw returns raw bytes (not a str), which load_model accepts.
        raw_model = booster.save_raw('ubj')
        regressor = xgb.XGBRegressor()
        regressor.load_model(raw_model)
        return regressor
Give us a ⭐ on GitHub