SparkLGBMForecast

Spark LightGBM forecaster

Wrapper of synapse.ml.lightgbm.LightGBMRegressor that adds an extract_local_model method, which gets a local version of the trained model so that it can be broadcast to the workers.

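import lightgbm as lgb

try:
    from synapse.ml.lightgbm import LightGBMRegressor
except ModuleNotFoundError:
    import os

    # Let the docs preview and the test suite import this module without SynapseML installed.
    if os.getenv('QUARTO_PREVIEW', '0') == '1' or os.getenv('IN_TEST', '0') == '1':
        LightGBMRegressor = object
    else:
        raise


class SparkLGBMForecast(LightGBMRegressor):
    def _pre_fit(self, target_col):
        # Point the estimator at the target column before fitting.
        return self.setLabelCol(target_col)

    def extract_local_model(self, trained_model):
        # getNativeModel() returns the trained booster serialized as a LightGBM model string.
        model_str = trained_model.getNativeModel()
        # Rebuild a plain lightgbm Booster from that string; it can then be broadcast to the workers.
        local_model = lgb.Booster(model_str=model_str)
        return local_model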
