ray.tune.Callback.get_state#

Callback.get_state() Dict | None[源代码]#

获取回调的状态。

此方法应由子类实现,以返回对象当前状态的字典表示。

这是 Tune 自动调用的,用于定期检查点回调状态。在 Tune 实验恢复 时,回调状态将通过 set_state() 恢复。

from typing import Dict, List, Optional

from ray.tune import Callback
from ray.tune.experiment import Trial

class MyCallback(Callback):
    def __init__(self):
        self._trial_ids = set()

    def on_trial_start(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        self._trial_ids.add(trial.trial_id)

    def get_state(self) -> Optional[Dict]:
        return {"trial_ids": self._trial_ids.copy()}

    def set_state(self, state: Dict) -> Optional[Dict]:
        self._trial_ids = state["trial_ids"]
返回:

回调的状态。如果回调没有任何状态需要保存,则应为 `None`(这是默认值)。

返回类型:

dict