Shortcuts

torch.distributions.multinomial 的源代码

import torch
from torch import inf
from torch.distributions import Categorical, constraints
from torch.distributions.binomial import Binomial
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all

__all__ = ["Multinomial"]


[docs]class Multinomial(Distribution): r""" 创建一个由 :attr:`total_count` 参数化的多项分布, 并且可以通过 :attr:`probs` 或 :attr:`logits` 来定义(但不能同时使用两者)。 :attr:`probs` 的最内层维度索引类别。所有其他维度索引批次。 注意:如果只调用 :meth:`log_prob`,则不需要指定 :attr:`total_count`(见下面的示例) .. 注意:: `probs` 参数必须是非负的、有限的并且具有非零和, 它将被归一化,使其沿最后一个维度总和为 1。:attr:`probs` 将返回这个归一化值。 `logits` 参数将被解释为未归一化的对数概率 因此可以是任何实数。它同样将被归一化,使得 生成的概率沿最后一个维度总和为 1。:attr:`logits` 将返回这个归一化值。 - :meth:`sample` 需要所有参数和样本共享一个 `total_count`。 - :meth:`log_prob` 允许每个参数和样本有不同的 `total_count`。 示例:: >>> # xdoctest: +SKIP("FIXME: found invalid values") >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.])) >>> x = m.sample() # 0, 1, 2, 3 的概率相等 tensor([ 21., 24., 30., 25.]) >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x) tensor([-4.1338]) 参数: total_count (int): 试验次数 probs (Tensor): 事件概率 logits (Tensor): 事件对数概率(未归一化) """ arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} total_count: int @property def mean(self): return self.probs * self.total_count @property def variance(self): return self.total_count * self.probs * (1 - self.probs) def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): if not isinstance(total_count, int): raise NotImplementedError("inhomogeneous total_count is not supported") self.total_count = total_count self._categorical = Categorical(probs=probs, logits=logits) self._binomial = Binomial(total_count=total_count, probs=self.probs) batch_shape = self._categorical.batch_shape event_shape = self._categorical.param_shape[-1:] super().__init__(batch_shape, event_shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Multinomial, _instance) batch_shape = torch.Size(batch_shape) new.total_count = self.total_count new._categorical = self._categorical.expand(batch_shape) super(Multinomial, new).__init__( batch_shape, self.event_shape, validate_args=False ) new._validate_args = self._validate_args return new
def _new(self, *args, **kwargs): return self._categorical._new(*args, **kwargs) @constraints.dependent_property(is_discrete=True, event_dim=1) def support(self): return constraints.multinomial(self.total_count) @property def logits(self): return self._categorical.logits @property def probs(self): return self._categorical.probs @property def param_shape(self):
优云智算