Loss Functions

::: {#cell-1 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’}

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab

:::

# Loss Functions
### Default class level 3

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

from __future__ import annotations
from fastai.imports import *
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.layers import *

:::

::: {#cell-4 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’}

from nbdev.showdoc import *

:::

Custom fastai loss functions

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    activation=decodes=noops
    def __init__(self, 
        loss_cls, # Uninitialized PyTorch-compatible loss
        *args,
        axis:int=-1, # Class axis
        flatten:bool=True, # Flatten `inp` and `targ` before calculating the loss
        floatify:bool=False, # Convert `targ` to `float`
        is_2d:bool=True, # Whether `flatten` keeps one or two channels when applied
        **kwargs
    ):
        store_attr("axis,flatten,floatify,is_2d")
        self.func = loss_cls(*args,**kwargs)
        self.func.__annotations__ = typing.get_type_hints(self.func, globalns=globals(), localns=locals()) # To avoid unpicklable loss functions (https://github.com/fastai/fastai/issues/3901)
        functools.update_wrapper(self, self.func)

    def __repr__(self) -> str: return f"FlattenedLoss of {self.func}"
    
    @property
    def reduction(self) -> str: return self.func.reduction
    
    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')" 
        self.func.reduction = v

    def _contiguous(self, x:Tensor) -> TensorBase:
        "Move `self.axis` to the last dimension and ensure tensor is contigous for `Tensor` otherwise just return"
        return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x

    def __call__(self, 
        inp:Tensor|MutableSequence, # Predictions from a `Learner`
        targ:Tensor|MutableSequence, # Actual y label
        **kwargs
    ) -> TensorBase: # `loss_cls` calculated on `inp` and `targ`
        inp,targ  = map(self._contiguous, (inp,targ))
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
    
    def to(self, device:torch.device):
        "Move the loss function to a specified `device`"
        if isinstance(self.func, nn.Module): self.func.to(device)

:::

Wrapping a general loss function inside of `BaseLoss` provides extra functionalities to your loss functions:

The `args` and `kwargs` will be passed to `loss_cls` during the initialization to instantiate a loss function. `axis` is put at the end for losses like softmax that are often performed on the last axis. If `floatify=True`, the `targs` will be converted to floats (useful for losses that only accept float targets like `BCEWithLogitsLoss`), and `is_2d` determines if we flatten while keeping the first dimension (batch size) or completely flatten the input. We want the first for losses like Cross Entropy, and the second otherwise.
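
For example, a minimal usage sketch (not part of the library tests): wrapping `nn.MSELoss` directly in `BaseLoss` with `is_2d=False` completely flattens both tensors and, with `floatify=True`, converts integer targets to floats before the underlying loss is called.

tst = BaseLoss(nn.MSELoss, floatify=True, is_2d=False)  # flatten completely, cast `targ` to float
output = torch.randn(32, 5, 10)
target = torch.randint(0, 2, (32, 5, 10))  # integer targets: plain nn.MSELoss() would reject this dtype
_ = tst(output, target)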

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@delegates()
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True # y interpolation
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, 
        *args, 
        axis:int=-1, # Class axis
        **kwargs
    ): 
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    
    def decodes(self, x:Tensor) -> Tensor:    
        "Converts model output to target format"
        return x.argmax(dim=self.axis)
    
    def activation(self, x:Tensor) -> Tensor: 
        "`nn.CrossEntropyLoss`'s fused activation function applied to model output"
        return F.softmax(x, dim=self.axis)

:::

tst = CrossEntropyLossFlat(reduction='none')
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropy would fail with those two tensors, but our flattened version works
_ = tst(output, target)

test_fail(lambda x: nn.CrossEntropyLoss()(output,target))

#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes, which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))

In a segmentation task, we want to take the softmax over the channel dimension.

tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)

test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

::: {#cell-12 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’ 4=’ ’ 5=‘c’ 6=‘u’ 7=‘d’ 8=‘a’}

tst = CrossEntropyLossFlat(weight=torch.ones(10))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tst.to(device)
output = torch.randn(32, 10, device=device)
target = torch.randint(0, 10, (32,), device=device)
_ = tst(output, target)

:::

Focal Loss is the same as cross entropy except easy-to-classify observations are down-weighted in the loss calculation. The strength of down-weighting is proportional to the size of the `gamma` parameter. Put another way, the larger `gamma`, the less the easy-to-classify observations contribute to the loss.
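
As a quick sketch of that weighting (following Lin et al., 2017, and matching the `forward` implemented below, ignoring the optional class `weight`): with $p_t$ the predicted probability of the true class, the per-element loss is

$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma}\,\log(p_t)$$

so examples with $p_t$ close to 1 contribute almost nothing when $\gamma > 0$, while $\gamma = 0$ recovers plain cross entropy.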

::: {#cell-14 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class FocalLoss(Module):
    y_int=True # y interpolation
    def __init__(self, 
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        weight:Tensor=None, # Manual rescaling weight given to each class
        reduction:str='mean' # PyTorch reduction to apply to the output
    ): 
        "Applies Focal Loss: https://arxiv.org/pdf/1708.02002.pdf"
        store_attr()
    
    def forward(self, inp:Tensor, targ:Tensor) -> Tensor:
        "Applies focal loss based on https://arxiv.org/pdf/1708.02002.pdf"
        ce_loss = F.cross_entropy(inp, targ, weight=self.weight, reduction="none")
        p_t = torch.exp(-ce_loss)
        loss = (1 - p_t)**self.gamma * ce_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss


class FocalLossFlat(BaseLoss):
    """
    Same as CrossEntropyLossFlat but with focal parameter, `gamma`. Focal loss is introduced by Lin et al. 
    https://arxiv.org/pdf/1708.02002.pdf. Note the class weighting factor in the paper, alpha, can be 
    implemented through pytorch `weight` argument passed through to F.cross_entropy.
    """
    y_int = True # y interpolation
    @use_kwargs_dict(keep=True, weight=None, reduction='mean')
    def __init__(self, 
        *args, 
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        axis:int=-1, # Class axis
        **kwargs
    ):
        super().__init__(FocalLoss, *args, gamma=gamma, axis=axis, **kwargs)
        
    def decodes(self, x:Tensor) -> Tensor: 
        "Converts model output to target format"
        return x.argmax(dim=self.axis)
    
    def activation(self, x:Tensor) -> Tensor: 
        "`F.cross_entropy`'s fused activation function applied to model output"
        return F.softmax(x, dim=self.axis)

:::

#Compare focal loss with gamma = 0 to cross entropy
fl = FocalLossFlat(gamma=0)
ce = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(fl(output, target), ce(output, target))
#Test focal loss with gamma > 0 is different than cross entropy
fl = FocalLossFlat(gamma=2)
test_ne(fl(output, target), ce(output, target))

In a segmentation task, we want to take the softmax over the channel dimension.

fl = FocalLossFlat(gamma=0, axis=1)
ce = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
test_close(fl(output, target), ce(output, target), eps=1e-4)
test_eq(fl.activation(output), F.softmax(output, dim=1))
test_eq(fl.decodes(output), output.argmax(dim=1))

::: {#cell-17 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@delegates()
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    @use_kwargs_dict(keep=True, weight=None, reduction='mean', pos_weight=None)
    def __init__(self, 
        *args, 
        axis:int=-1, # Class axis
        floatify:bool=True, # Convert `targ` to `float`
        thresh:float=0.5, # The threshold on which to predict
        **kwargs
    ):
        if kwargs.get('pos_weight', None) is not None and kwargs.get('flatten', None) is True:
            raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
        if kwargs.get('pos_weight', None) is not None: kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x>self.thresh
    
    def activation(self, x:Tensor) -> Tensor:
        "`nn.BCEWithLogitsLoss`'s fused activation function applied to model output"
        return torch.sigmoid(x)

:::

tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but our flattened version works
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with integer targets, but our flattened version works
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))

tst = BCEWithLogitsLossFlat(pos_weight=torch.ones(10))
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))

#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))

::: {#cell-19 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@use_kwargs_dict(weight=None, reduction='mean')
def BCELossFlat(
    *args, 
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.BCELoss`, but flattens input and target."
    return BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

:::

tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.BCELoss()(output,target))

::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@use_kwargs_dict(reduction='mean')
def MSELossFlat(
    *args, 
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.MSELoss`, but flattens input and target."
    return BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

:::

tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.MSELoss()(output,target))

::: {#cell-23 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’ 4=’ ’ 5=‘C’ 6=‘U’ 7=‘D’ 8=‘A’}

#Test losses work in half precision
if torch.cuda.is_available():
    output = torch.sigmoid(torch.randn(32, 5, 10)).half().cuda()
    target = torch.randint(0,2,(32, 5, 10)).half().cuda()
    for tst in [BCELossFlat(), MSELossFlat()]: _ = tst(output, target)

:::

::: {#cell-24 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@use_kwargs_dict(reduction='mean')
def L1LossFlat(
    *args, 
    axis=-1, # Class axis
    floatify=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.L1Loss`, but flattens input and target."
    return BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

:::

::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class LabelSmoothingCrossEntropy(Module):
    y_int = True # y interpolation
    def __init__(self, 
        eps:float=0.1, # The weight for the interpolation formula
        weight:Tensor=None, # Manual rescaling weight given to each class passed to `F.nll_loss`
        reduction:str='mean' # PyTorch reduction to apply to the output
    ): 
        store_attr()

    def forward(self, output:Tensor, target:Tensor) -> Tensor:
        "Apply `F.log_softmax` on output then blend the loss/num_classes(`c`) with the `F.nll_loss`"
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) #We divide by that size at the return line, so sum and not mean
            if self.reduction=='mean':  loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)

    def activation(self, out:Tensor) -> Tensor: 
        "`F.log_softmax`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)
    
    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)

:::

lmce = LabelSmoothingCrossEntropy()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(lmce(output.flatten(0,1), target.flatten()), lmce(output.transpose(-1,-2), target))
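
For reference, the loss computed by `LabelSmoothingCrossEntropy` above can be written, per example and before the reduction (a restatement of the code, ignoring the optional class `weight`, with $c$ classes, smoothing weight $\epsilon$, and $p = \mathrm{softmax}(\text{output})$):

$$\mathcal{L} = (1-\epsilon)\,\bigl(-\log p_{y}\bigr) + \frac{\epsilon}{c}\sum_{j=1}^{c}\bigl(-\log p_{j}\bigr)$$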

On top of the formula we define:

::: {#cell-28 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

@delegates()
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    "Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
    y_int = True
    @use_kwargs_dict(keep=True, eps=0.1, reduction='mean')
    def __init__(self, 
        *args, 
        axis:int=-1, # Class axis
        **kwargs
    ): 
        super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)
    def activation(self, out:Tensor) -> Tensor: 
        "`LabelSmoothingCrossEntropy`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)
    
    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)

:::

#These two should always be equal, since the Flat version is just passing data through
lmce = LabelSmoothingCrossEntropy()
lmce_flat = LabelSmoothingCrossEntropyFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
test_close(lmce(output.transpose(-1,-2), target), lmce_flat(output,target))

We present a general `Dice` loss for segmentation tasks. It is commonly used together with `CrossEntropyLoss` or `FocalLoss` in kaggle competitions. This is very similar to the `DiceMulti` metric, but to be able to differentiate through it, we replace the `argmax` activation by a `softmax` and compare this with a one-hot encoded target mask. This function also adds a `smooth` parameter to help numerical stability in the division of intersection over union. If your network has trouble learning with this `DiceLoss`, try setting the `square_in_union` parameter in the `DiceLoss` constructor to `True`.
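
Concretely, with $p$ the softmax predictions and $t$ the one-hot encoded target, the per-class score computed below (ignoring the optional `square_in_union` variant) is

$$\mathrm{dice} = \frac{2\sum_i p_i\,t_i + \mathrm{smooth}}{\sum_i p_i + \sum_i t_i + \mathrm{smooth}}, \qquad \mathrm{loss} = 1 - \mathrm{dice}$$

where the sums run over the spatial dimensions of each class channel.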

::: {#cell-31 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DiceLoss:
    "Dice loss for segmentation"
    def __init__(self, 
        axis:int=1, # Class axis
        smooth:float=1e-6, # Helps with numerical stability in the IoU division
        reduction:str="sum", # PyTorch reduction to apply to the output
        square_in_union:bool=False # Squares predictions to increase slope of gradients
    ):
        store_attr()
        
    def __call__(self, pred:Tensor, targ:Tensor) -> Tensor:
        "One-hot encodes targ, then runs IoU calculation then takes 1-dice value"
        targ = self._one_hot(targ, pred.shape[self.axis])
        pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
        pred = self.activation(pred)
        sum_dims = list(range(2, len(pred.shape)))
        inter = torch.sum(pred*targ, dim=sum_dims)        
        union = (torch.sum(pred**2+targ, dim=sum_dims) if self.square_in_union
            else torch.sum(pred+targ, dim=sum_dims))
        dice_score = (2. * inter + self.smooth)/(union + self.smooth)
        loss = 1- dice_score
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
    @staticmethod
    def _one_hot(
        x:Tensor, # Non one-hot encoded targs
        classes:int, # The number of classes
        axis:int=1 # The axis to stack for encoding (class dimension)
    ) -> Tensor:
        "Creates one binary mask per class"
        return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], axis=axis)
    
    def activation(self, x:Tensor) -> Tensor: 
        "Activation function applied to model output"
        return F.softmax(x, dim=self.axis)
    
    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)

:::

dl = DiceLoss()
_x         = tensor( [[[1, 0, 2],
                       [2, 2, 1]]])
_one_hot_x = tensor([[[[0, 1, 0],
                       [0, 0, 0]],
                      [[1, 0, 0],
                       [0, 0, 1]],
                      [[0, 0, 1],
                       [1, 1, 0]]]])
test_eq(dl._one_hot(_x, 3), _one_hot_x)
dl = DiceLoss()
model_output = tensor([[[[2., 1.],
                         [1., 5.]],
                        [[1,  2.],
                         [3., 1.]],
                        [[3., 0],
                         [4., 3.]]]])
target       =  tensor([[[2, 1],
                         [2, 0]]])
dl_out = dl(model_output, target)
test_eq(dl.decodes(model_output), target)
dl = DiceLoss(reduction="mean")
#identical masks
model_output = tensor([[[.1], [.1], [100.]]])
target = tensor([[2]])
test_close(dl(model_output, target), 0)

#50% intersection
model_output = tensor([[[.1, 100.], [.1, .1], [100., .1]]])
target = tensor([[2, 1]])
test_close(dl(model_output, target), .66, eps=0.01)

As a test case for the dice loss, consider satellite image segmentation. Assume we have three classes: background (0), river (1) and road (2). Let us look at a specific target.

target = torch.zeros(100,100)
target[:,5] = 1
target[:,50] = 2
plt.imshow(target);

Nearly everything in this example is background, and we have a thin river at the left of the image as well as a thin road in the middle of the image. If all our data looks similar to this, we say there is a class imbalance, meaning that some classes (like river and road) appear relatively rarely. If our model just predicted "background" (i.e. the value 0) for all pixels, it would be correct for most pixels. But this would be a bad model, and the dice loss should reflect that.

model_output_all_background = torch.zeros(3, 100,100)
# Assign probability 1 to class 0 everywhere
# To get probability 1, we just need a high model output before softmax gets applied
model_output_all_background[0,:,:] = 100
# Add a batch dimension
model_output_all_background = torch.unsqueeze(model_output_all_background,0)
target = torch.unsqueeze(target,0)

Our dice score should be around 1/3 here, because the "background" class is predicted correctly for (nearly) every pixel, but the other two classes are never predicted correctly. A dice score of 1/3 means a dice loss of 1 - 1/3 = 2/3:

test_close(dl(model_output_all_background, target), 0.67, eps=0.01)

If the model predicts everything correctly, the dice loss should be zero:

correct_model_output = torch.zeros(3, 100,100)
correct_model_output[0,:,:] = 100
correct_model_output[0,:,5] = 0
correct_model_output[0,:,50] = 0
correct_model_output[1,:,5] = 100
correct_model_output[2,:,50] = 100
correct_model_output = torch.unsqueeze(correct_model_output, 0)
test_close(dl(correct_model_output, target), 0)

::: {#cell-45 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’ 4=’ ’ 5=‘c’ 6=‘u’ 7=‘d’ 8=‘a’}

#Test DiceLoss works in half precision
if torch.cuda.is_available():
    output = torch.randn(32, 4, 5, 10).half().cuda()
    target = torch.randint(0,2,(32, 5, 10)).half().cuda()
    _ = dl(output, target)

:::

You could easily combine this loss with `FocalLoss`, defining a `CombinedLoss`, to balance between global (Dice) and local (Focal) features on the target mask.

class CombinedLoss:
    "Dice and Focal combined"
    def __init__(self, axis=1, smooth=1., alpha=1.):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss =  DiceLoss(axis, smooth)
        
    def __call__(self, pred, targ):
        return self.focal_loss(pred, targ) + self.alpha * self.dice_loss(pred, targ)
    
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
cl = CombinedLoss()
output = torch.randn(32, 4, 5, 10)
target = torch.randint(0,2,(32, 5, 10))
_ = cl(output, target)
# Tests to catch future changes to pickle which cause some loss functions to be 'unpicklable'.
# This causes problems with `Learner.export` as the model can't be pickled with these particular loss functions.

losses_picklable = [
 (BCELossFlat(), True),
 (BCEWithLogitsLossFlat(), True),
 (CombinedLoss(), True),
 (CrossEntropyLossFlat(), True),
 (DiceLoss(), True),
 (FocalLoss(), True),
 (FocalLossFlat(), True),
 (L1LossFlat(), True),
 (LabelSmoothingCrossEntropyFlat(), True),
 (LabelSmoothingCrossEntropy(), True),
 (MSELossFlat(), True),
]

for loss, picklable in losses_picklable:
    try:
        pickle.dumps(loss, protocol=2)
    except (pickle.PicklingError, TypeError) as e:
        if picklable:
            # Loss was previously picklable but isn't currently
            raise e

Export -

::: {#cell-51 .cell 0=‘h’ 1=‘i’ 2=‘d’ 3=‘e’}

from nbdev import *
nbdev_export()

:::