torch.nn.utils.prune.is_pruned¶
- torch.nn.utils.prune.is_pruned(module)[源代码]¶
检查模块是否被剪枝,通过查找剪枝预钩子。
检查
module是否被修剪,方法是查找其模块中继承自BasePruningMethod的forward_pre_hooks。- Parameters
模块 (nn.Module) – 被剪枝或未被剪枝的对象
- Returns
是否对
module进行了剪枝的二进制答案。
示例
>>> from torch.nn.utils import prune >>> m = nn.Linear(5, 7) >>> print(prune.is_pruned(m)) False >>> prune.random_unstructured(m, name='weight', amount=0.2) >>> print(prune.is_pruned(m)) True