import lightgbm as lgb
# SynapseML is an optional dependency. When it is missing, fall back to a
# plain ``object`` base class only if QUARTO_PREVIEW or IN_TEST is set to
# '1', so this module can still be imported; otherwise re-raise the original
# ModuleNotFoundError.
try:
    from synapse.ml.lightgbm import LightGBMRegressor
except ModuleNotFoundError:
    import os

    if os.getenv('QUARTO_PREVIEW', '0') == '1' or os.getenv('IN_TEST', '0') == '1':
        # Placeholder base so the class definition below still evaluates.
        LightGBMRegressor = object
    else:
        raise
SparkLGBMForecast
spark LightGBM 预测器
`synapse.ml.lightgbm.LightGBMRegressor` 的包装器,添加了 `extract_local_model` 方法,以获取训练模型的本地版本并将其广播到工作节点。
class SparkLGBMForecast(LightGBMRegressor):
    """Wrapper around ``synapse.ml.lightgbm.LightGBMRegressor``.

    Adds an ``extract_local_model`` method that turns a trained model into a
    local ``lightgbm.Booster``.
    """

    def _pre_fit(self, target_col):
        """Set ``target_col`` as the label column.

        Returns whatever ``setLabelCol`` returns.
        """
        return self.setLabelCol(target_col)

    def extract_local_model(self, trained_model):
        """Build a local ``lightgbm.Booster`` from ``trained_model``.

        Parameters
        ----------
        trained_model
            Fitted model exposing ``getNativeModel()``; its result is passed
            to ``lgb.Booster`` as ``model_str``.

        Returns
        -------
        lgb.Booster
            Booster constructed from the native model string.
        """
        model_str = trained_model.getNativeModel()
        local_model = lgb.Booster(model_str=model_str)
        return local_model
Give us a ⭐ on Github