高斯NLL损失¶
- class torch.nn.GaussianNLLLoss(*, full=False, eps=1e-06, reduction='mean')[源代码]¶
高斯负对数似然损失。
目标被视为来自高斯分布的样本,其期望值和方差由神经网络预测。对于一个
target张量,假设其具有高斯分布,期望值张量为input,正方差张量为var,损失为:其中
eps用于稳定性。默认情况下,除非full为True,否则损失函数的常数项将被省略。如果var与input的大小不同(由于同方差假设),它必须具有最后一个维度为1,或者具有少一个维度(所有其他大小相同)以正确广播。- Parameters
- Shape:
输入: 或 其中 表示任意数量的额外维度
目标: 或 ,与输入形状相同,或与输入形状相同但其中一个维度等于1(以允许广播)
变量: 或 , 与输入形状相同,或与输入形状相同但其中一个维度等于1,或与输入形状相同但维度减少一个(以允许广播)
输出:如果
reduction是'mean'(默认)或'sum',则为标量。如果reduction是'none',则为 ,与输入形状相同
- Examples::
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 2, requires_grad=True) # 异方差 >>> output = loss(input, target, var) >>> output.backward()
>>> loss = nn.GaussianNLLLoss() >>> input = torch.randn(5, 2, requires_grad=True) >>> target = torch.randn(5, 2) >>> var = torch.ones(5, 1, requires_grad=True) # 同方差 >>> output = loss(input, target, var) >>> output.backward()
注意
对于自动求导,
var的钳制操作会被忽略,因此梯度不会受到影响。- Reference:
Nix, D. A. 和 Weigend, A. S., “估计目标概率分布的均值和方差”, 1994年IEEE国际神经网络会议(ICNN’94)论文集, 美国佛罗里达州奥兰多, 1994年, 第55-60页, 卷1, doi: 10.1109/ICNN.1994.374138.