ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.损失#

EagerTFPolicyV2.loss(model: ModelV2 | tf.keras.Model, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor][源代码]#

使用模型、dist_class 和 train_batch 计算此策略的损失。

参数:
  • model – 用于计算损失的模型。

  • dist_class – 动作分布类。

  • train_batch – 训练数据。

返回:

单个损失张量或损失张量列表。