ray.rllib.policy.Policy.loss#
- Policy.loss(model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [源代码]#
此策略的损失函数。
在子类中重写此方法以实现自定义损失计算。
- 参数:
model – 用于计算损失的模型。
dist_class – 动作分布类,用于从模型的输出中采样动作。
train_batch – 要计算损失的输入批次。
- 返回:
单个损失张量或损失张量列表。