ray.rllib.evaluation.sampler.SyncSampler.tf_输入操作#

SyncSampler.tf_input_ops(queue_size: int = 1) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]#

返回用于从此读取器读取输入的 TensorFlow 队列操作。

这些操作的主要用途是集成到自定义模型损失中。例如,您可以使用 tf_input_ops() 从外部经验文件中读取,以向您的模型添加模仿学习损失。

此方法创建一个队列运行器线程,该线程将反复调用此读取器的 next() 方法以向 TensorFlow 队列提供数据。

参数:

queue_size – TF 队列中允许的最大元素数。

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.offline.json_reader import JsonReader
imitation_loss = ...
class MyModel(ModelV2):
    def custom_loss(self, policy_loss, loss_inputs):
        reader = JsonReader(...)
        input_ops = reader.tf_input_ops()
        logits, _ = self._build_layers_v2(
            {"obs": input_ops["obs"]},
            self.num_outputs, self.options)
        il_loss = imitation_loss(logits, input_ops["action"])
        return policy_loss + il_loss

你可以在 examples/custom_loss.py 中找到这个的可运行版本。

返回:

每个读取的 SampleBatch 列对应一个张量的字典。