ray.rllib.policy.torch_policy_v2.TorchPolicyV2.损失#

TorchPolicyV2.loss(model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor][源代码]#

构建损失函数。

参数:
  • model – 用于计算损失的模型。

  • dist_class – 动作分布类。

  • train_batch – 训练数据。

返回:

给定输入批次的损失张量。