损失平衡器
蒸馏任务的基本损失平衡器。
类
损失平衡器的接口。 |
|
基于静态权重的KD损失聚合。 |
- class DistillationLossBalancer
基础:
Module损失平衡器的接口。
- __init__()
构造函数。
- abstract forward(loss)
计算聚合损失。
- Parameters:
loss (Dict[str, Tensor]) – 要聚合的损失字典。 键将是应用于获取损失的损失函数的类名,后缀为_{idx}以确保唯一性。如果提供了学生损失给
mtd.DistillationModel.compute_kd_loss,则它将具有键mtd.loss_balancers.STUDENT_LOSS_KEY。例如,如果criterion参数 传递给mtd.convert的是{("mod1_s", "mod1_t"): torch.nn.MSELoss(), ("mod2_s", "mod2_t"): torch.nn.MSELoss()}并且提供给mtd.DistillationModel.compute_kd_loss的student_loss不是 None,那么这里的损失字典将看起来像{"student_loss": torch.tensor(...), "MSELoss_0": torch.tensor(...), "MSELoss_1": torch.tensor(...)}。- Returns:
平衡学生和kd损失组件后的总损失。
- Return type:
张量
- set_student_loss_reduction_fn(student_loss_reduction_fn)
设置学生损失减少函数值。
在平衡之前减少学生损失的特殊情况下需要。
- Parameters:
student_loss_reduction_fn (Callable[[Any], Tensor]) –
- class StaticLossBalancer
-
基于静态权重的KD损失聚合。
- __init__(kd_loss_weight=0.5)
构造函数。
- Parameters:
kd_loss_weight (float | List[float]) – 用于平衡知识蒸馏损失和原始学生损失的静态权重。 如果它是一个浮点数,它将应用于sum(KD losses)。 如果它是一个列表,键是KD损失键,按照
criterion参数中指定的顺序,每个键对应的权重将应用于相应的损失值。 如果权重之和不等于1.0,应将student_loss传递给mtd.DistillationModel.compute_kd_loss,权重差异将应用于此损失值。- Raises:
如果kd_loss_weight超出范围,则抛出ValueError –
- forward(loss)
计算聚合损失。
- Parameters:
loss (Dict[str, Tensor]) – 需要聚合的损失字典。
- Returns:
平衡学生和kd损失组件后的总损失。
- Return type:
张量