ray.tune.with_parameters#

ray.tune.with_parameters(trainable: Type[Trainable] | Callable, **kwargs)[源代码]#

用于可训练对象的包装器,以传递任意大的数据对象。

这个包装函数将存储所有传递的参数在 Ray 对象存储中,并在调用函数时检索它们。因此,它可以用于传递任意数据,甚至是数据集,到 Tune 可训练对象。

这也可以作为 functools.partial 的替代方案,用于向可训练对象传递默认参数。

当与函数API一起使用时,可训练函数会以关键字参数的形式调用传递的参数。当与类API一起使用时,Trainable.setup() 方法会以相应的kwargs调用。

如果数据已经存在于对象存储中(是 ObjectRef 的实例),则不需要使用 tune.with_parameters()。你可以通过 config 将对象引用传递给训练函数,或者使用 Python 的部分应用。

参数:
  • trainable – 可训练的包装。

  • **kwargs – 存储在对象存储中的参数。

函数 API 示例:

from ray import train, tune

def train_fn(config, data=None):
    for sample in data:
        loss = update_model(sample)
        train.report(loss=loss)

data = HugeDataset(download=True)

tuner = Tuner(
    tune.with_parameters(train_fn, data=data),
    # ...
)
tuner.fit()

类 API 示例:

from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config, data=None):
        self.data = data
        self.iter = iter(self.data)
        self.next_sample = next(self.iter)

    def step(self):
        loss = update_model(self.next_sample)
        try:
            self.next_sample = next(self.iter)
        except StopIteration:
            return {"loss": loss, done: True}
        return {"loss": loss}

data = HugeDataset(download=True)

tuner = Tuner(
    tune.with_parameters(MyTrainable, data=data),
    # ...
)

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。