ray.tune.with_parameters#
- ray.tune.with_parameters(trainable: Type[Trainable] | Callable, **kwargs)[源代码]#
用于可训练对象的包装器,以传递任意大的数据对象。
这个包装函数将存储所有传递的参数在 Ray 对象存储中,并在调用函数时检索它们。因此,它可以用于传递任意数据,甚至是数据集,到 Tune 可训练对象。
这也可以作为
functools.partial
的替代方案,用于向可训练对象传递默认参数。当与函数API一起使用时,可训练函数会以关键字参数的形式调用传递的参数。当与类API一起使用时,
Trainable.setup()
方法会以相应的kwargs调用。如果数据已经存在于对象存储中(是 ObjectRef 的实例),则不需要使用
tune.with_parameters()
。你可以通过config
将对象引用传递给训练函数,或者使用 Python 的部分应用。- 参数:
trainable – 可训练的包装。
**kwargs – 存储在对象存储中的参数。
函数 API 示例:
from ray import train, tune def train_fn(config, data=None): for sample in data: loss = update_model(sample) train.report(loss=loss) data = HugeDataset(download=True) tuner = Tuner( tune.with_parameters(train_fn, data=data), # ... ) tuner.fit()
类 API 示例:
from ray import tune class MyTrainable(tune.Trainable): def setup(self, config, data=None): self.data = data self.iter = iter(self.data) self.next_sample = next(self.iter) def step(self): loss = update_model(self.next_sample) try: self.next_sample = next(self.iter) except StopIteration: return {"loss": loss, done: True} return {"loss": loss} data = HugeDataset(download=True) tuner = Tuner( tune.with_parameters(MyTrainable, data=data), # ... )
PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。