gradnas

实现用于搜索的gradnas剪枝算法的模块。

总结:

gradnas 算法在语言模型中对各种剪枝选择进行排序时,比 L1 范数(fastnas)给出了更好的评分。

详情:

此外,我们可以获得即使抽象实现的超参数的分数。 例如,我们可以使用此算法对多头注意力层中的头进行排序。注意力头 没有与之关联的唯一张量参数。

我们正在根据损失相对于修剪掩码的梯度的平方和来对特定超参数的可修剪选择进行排名。 超参数的修剪掩码是一个二进制掩码,指示超参数的哪些选择被修剪 (0表示被修剪,1表示未被修剪)。

在计算损失的向后梯度时,所有张量的掩码都设置为1。 有关使用掩码来测量敏感性的更多信息,请参阅本文:https://arxiv.org/pdf/1905.10650.pdf

GradientBinarySearcher

梯度算法的二分搜索器。

GradientDataManager

用于管理hparam的梯度数据的类。

class GradientBinarySearcher

基础类: BinarySearcher

梯度算法的二分搜索器。

SETUP_GRADIENT_FUNC: Dict[Type[动态模块], Callable[[动态模块], Tuple[GradientDataManager, RemovableHandle]]]

设置基于梯度的分数搜索。

Return type:

property default_search_config: Dict[str, Any]

获取搜索器的默认配置。

static gradnas_score_func(model)

gradnas算法的评分函数。

如果我们从层L中修剪N个神经元,总退化是这N个被修剪神经元的退化值之和。在fast算法中,修剪导致的退化直接从validation_score(修剪后的模型)中估计。算法的其余部分与fast算法完全相同。

Parameters:

模型 (模块) –

Return type:

浮点数

我们只能在梯度二分搜索中优化某些类型的超参数。

sanitize_search_config(config)

清理搜索配置字典。

Parameters:

config (Dict[str, Any] | None) –

Return type:

Dict[str, Any]

class GradientDataManager

基础类:object

用于管理hparam的梯度数据的类。

__init__(shape, model, reduce_func=<function GradientDataManager.<lambda>>)

初始化 GradientDataManager。

process_gradient()

处理掩码的梯度。

property score

基于存储梯度的hparam得分。