ray.rllib.models.distributions.Distribution.from_logits#
- classmethod Distribution.from_logits(logits: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, **kwargs) Distribution [源代码]#
从logits创建一个分布。
调用者不需要了解分布类的知识就可以创建它并从中采样。传递的批量对数向量可能会被分割,并作为关键字参数传递给分布类的构造函数。
- 参数:
logits – 用于创建分布的logits。
**kwargs – 向前兼容占位符。
- 返回:
创建的发行版。
import numpy as np from ray.rllib.models.distributions import Distribution class Uniform(Distribution): def __init__(self, lower, upper): self.lower = lower self.upper = upper def sample(self): return self.lower + (self.upper - self.lower) * np.random.rand() def logp(self, x): ... def kl(self, other): ... def entropy(self): ... @staticmethod def required_input_dim(space): ... def rsample(self): ... @classmethod def from_logits(cls, logits, **kwargs): return Uniform(logits[:, 0], logits[:, 1]) logits = np.array([[0.0, 1.0], [2.0, 3.0]]) my_dist = Uniform.from_logits(logits) sample = my_dist.sample()