torch.distributions.multinomial 的源代码
import torch
from torch import inf
from torch.distributions import Categorical, constraints
from torch.distributions.binomial import Binomial
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
__all__ = ["Multinomial"]
[docs]class Multinomial(Distribution):
r"""
创建一个由 :attr:`total_count` 参数化的多项分布,
并且可以通过 :attr:`probs` 或 :attr:`logits` 来定义(但不能同时使用两者)。
:attr:`probs` 的最内层维度索引类别。所有其他维度索引批次。
注意:如果只调用 :meth:`log_prob`,则不需要指定 :attr:`total_count`(见下面的示例)
.. 注意:: `probs` 参数必须是非负的、有限的并且具有非零和,
它将被归一化,使其沿最后一个维度总和为 1。:attr:`probs`
将返回这个归一化值。
`logits` 参数将被解释为未归一化的对数概率
因此可以是任何实数。它同样将被归一化,使得
生成的概率沿最后一个维度总和为 1。:attr:`logits`
将返回这个归一化值。
- :meth:`sample` 需要所有参数和样本共享一个 `total_count`。
- :meth:`log_prob` 允许每个参数和样本有不同的 `total_count`。
示例::
>>> # xdoctest: +SKIP("FIXME: found invalid values")
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample() # 0, 1, 2, 3 的概率相等
tensor([ 21., 24., 30., 25.])
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
参数:
total_count (int): 试验次数
probs (Tensor): 事件概率
logits (Tensor): 事件对数概率(未归一化)
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int
@property
def mean(self):
return self.probs * self.total_count
@property
def variance(self):
return self.total_count * self.probs * (1 - self.probs)
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, int):
raise NotImplementedError("inhomogeneous total_count is not supported")
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
self._binomial = Binomial(total_count=total_count, probs=self.probs)
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Multinomial, _instance)
batch_shape = torch.Size(batch_shape)
new.total_count = self.total_count
new._categorical = self._categorical.expand(batch_shape)
super(Multinomial, new).__init__(
batch_shape, self.event_shape, validate_args=False
)
new._validate_args = self._validate_args
return new
def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
return constraints.multinomial(self.total_count)
@property
def logits(self):
return self._categorical.logits
@property
def probs(self):
return self._categorical.probs
@property
def param_shape(self):