ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.损失#
- EagerTFPolicyV2.loss(model: ModelV2 | tf.keras.Model, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [源代码]#
使用模型、dist_class 和 train_batch 计算此策略的损失。
- 参数:
model – 用于计算损失的模型。
dist_class – 动作分布类。
train_batch – 训练数据。
- 返回:
单个损失张量或损失张量列表。