损失平衡器

蒸馏任务的基本损失平衡器。

DistillationLossBalancer

损失平衡器的接口。

StaticLossBalancer

基于静态权重的KD损失聚合。

class DistillationLossBalancer

基础:Module

损失平衡器的接口。

__init__()

构造函数。

abstract forward(loss)

计算聚合损失。

Parameters:

loss (Dict[str, Tensor]) – 要聚合的损失字典。 键将是应用于获取损失的损失函数的类名,后缀为_{idx}以确保唯一性。如果提供了学生损失给 mtd.DistillationModel.compute_kd_loss,则它将具有键 mtd.loss_balancers.STUDENT_LOSS_KEY。例如,如果criterion参数 传递给mtd.convert的是 {("mod1_s", "mod1_t"): torch.nn.MSELoss(), ("mod2_s", "mod2_t"): torch.nn.MSELoss()} 并且提供给mtd.DistillationModel.compute_kd_loss的student_loss不是 None,那么这里的损失字典将看起来像 {"student_loss": torch.tensor(...), "MSELoss_0": torch.tensor(...), "MSELoss_1": torch.tensor(...)}

Returns:

平衡学生和kd损失组件后的总损失。

Return type:

张量

set_student_loss_reduction_fn(student_loss_reduction_fn)

设置学生损失减少函数值。

在平衡之前减少学生损失的特殊情况下需要。

Parameters:

student_loss_reduction_fn (Callable[[Any], Tensor]) –

class StaticLossBalancer

基础:DistillationLossBalancer

基于静态权重的KD损失聚合。

__init__(kd_loss_weight=0.5)

构造函数。

Parameters:

kd_loss_weight (float | List[float]) – 用于平衡知识蒸馏损失和原始学生损失的静态权重。 如果它是一个浮点数,它将应用于sum(KD losses)。 如果它是一个列表,键是KD损失键,按照criterion参数中指定的顺序,每个键对应的权重将应用于相应的损失值。 如果权重之和不等于1.0,应将student_loss传递给mtd.DistillationModel.compute_kd_loss,权重差异将应用于此损失值。

Raises:

如果kd_loss_weight超出范围,则抛出ValueError

forward(loss)

计算聚合损失。

Parameters:

loss (Dict[str, Tensor]) – 需要聚合的损失字典。

Returns:

平衡学生和kd损失组件后的总损失。

Return type:

张量