ray.rllib.offline.d4rl_reader.D4RLReader.tf_input_ops#
- D4RLReader.tf_input_ops(queue_size: int = 1) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] #
返回用于从此读取器读取输入的 TensorFlow 队列操作。
这些操作的主要用途是集成到自定义模型损失中。例如,您可以使用 tf_input_ops() 从外部经验文件中读取,以向您的模型添加模仿学习损失。
此方法创建一个队列运行器线程,该线程将反复调用此读取器的 next() 方法以向 TensorFlow 队列提供数据。
- 参数:
queue_size – TF 队列中允许的最大元素数。
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.offline.json_reader import JsonReader imitation_loss = ... class MyModel(ModelV2): def custom_loss(self, policy_loss, loss_inputs): reader = JsonReader(...) input_ops = reader.tf_input_ops() logits, _ = self._build_layers_v2( {"obs": input_ops["obs"]}, self.num_outputs, self.options) il_loss = imitation_loss(logits, input_ops["action"]) return policy_loss + il_loss
你可以在 examples/custom_loss.py 中找到这个的可运行版本。
- 返回:
每个读取的 SampleBatch 列对应一个张量的字典。