ray.rllib.policy.Policy.loss#

Policy.loss(model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor][源代码]#

此策略的损失函数。

在子类中重写此方法以实现自定义损失计算。

参数:
  • model – 用于计算损失的模型。

  • dist_class – 动作分布类,用于从模型的输出中采样动作。

  • train_batch – 要计算损失的输入批次。

返回:

单个损失张量或损失张量列表。