ray.train.torch.prepare_模型#

ray.train.torch.prepare_model(model: torch.nn.Module, move_to_device: bool | torch.device = True, parallel_strategy: str | None = 'ddp', parallel_strategy_kwargs: Dict[str, Any] | None = None) torch.nn.Module[源代码]#

为模型的分布式执行做准备。

这允许你使用相同的代码,无论工作者的数量或使用的设备类型(CPU、GPU)如何。

参数:
  • model (torch.nn.Module) – 一个准备好的 torch 模型。

  • move_to_device – 一个布尔值,指示是否将模型移动到正确的设备,或者是一个实际的设备,用于移动模型。如果设置为 False,则需要手动将模型移动到正确的设备。

  • parallel_strategy ("ddp", "fsdp", or None) – 是否将模型封装在 DistributedDataParallelFullyShardedDataParallel 中,或两者都不使用。

  • parallel_strategy_kwargs (Dict[str, Any]) – 如果``parallel_strategy``分别设置为”ddp”或”fsdp”,则传递给``DistributedDataParallel``或``FullyShardedDataParallel``初始化的参数。