ray.rllib.models.distributions.Distribution.from_logits#

classmethod Distribution.from_logits(logits: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, **kwargs) Distribution[源代码]#

从logits创建一个分布。

调用者不需要了解分布类的知识就可以创建它并从中采样。传递的批量对数向量可能会被分割,并作为关键字参数传递给分布类的构造函数。

参数:
  • logits – 用于创建分布的logits。

  • **kwargs – 向前兼容占位符。

返回:

创建的发行版。

import numpy as np
from ray.rllib.models.distributions import Distribution

class Uniform(Distribution):
    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def sample(self):
        return self.lower + (self.upper - self.lower) * np.random.rand()

    def logp(self, x):
        ...

    def kl(self, other):
        ...

    def entropy(self):
        ...

    @staticmethod
    def required_input_dim(space):
        ...

    def rsample(self):
        ...

    @classmethod
    def from_logits(cls, logits, **kwargs):
        return Uniform(logits[:, 0], logits[:, 1])

logits = np.array([[0.0, 1.0], [2.0, 3.0]])
my_dist = Uniform.from_logits(logits)
sample = my_dist.sample()