ray.rllib.env.env_runner_group.EnvRunnerGroup.同步权重#

EnvRunnerGroup.sync_weights(policies: List[str] | None = None, from_worker_or_learner_group: EnvRunner | LearnerGroup | None = None, to_worker_indices: List[int] | None = None, global_vars: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] | None = None, timeout_seconds: float | None = 0.0, inference_only: bool | None = False) None[源代码]#

将模型权重从给定的权重源同步到所有远程工作节点。

权重来源可以是(本地的)rollout worker 或 learner_group。它只需要实现一个 get_weights 方法。

参数:
  • policies – 可选的 PolicyID 列表,用于同步权重。如果为 None(默认),则同步所有策略的权重。

  • from_worker_or_learner_group – 可选的(本地)EnvRunner 实例或 LearnerGroup 实例以进行同步。如果为 None(默认),则从该 EnvRunnerGroup 的本地工作者同步。

  • to_worker_indices – 要同步权重的可选工作索引列表。如果为 None(默认),则同步到所有远程工作节点。

  • global_vars – 一个可选的全局变量字典,用于设置此工作器。如果为 None,则不更新 global_vars。

  • timeout_seconds – 等待同步权重调用完成的超时时间(以秒为单位)。默认值为 0.0(即发即弃,不等待任何同步调用完成)。根据算法的 training_step 逻辑,将其设置为 0.0 可能会显著提高算法性能。

  • inference_only – 与保持推理模块的工作者同步权重。这对于使用仅推理模块的新堆栈中的算法是必需的。在这种情况下,只有部分参数被同步到工作者。默认是 False。

开发者API: 此API可能会在Ray的次要版本之间发生变化。