Shortcuts

torch.nn.utils.prune.is_pruned

torch.nn.utils.prune.is_pruned(module)[源代码]

检查模块是否被剪枝,通过查找剪枝预钩子。

检查 module 是否被修剪,方法是查找其模块中继承自 BasePruningMethodforward_pre_hooks

Parameters

模块 (nn.Module) – 被剪枝或未被剪枝的对象

Returns

是否对 module 进行了剪枝的二进制答案。

示例

>>> from torch.nn.utils import prune
>>> m = nn.Linear(5, 7)
>>> print(prune.is_pruned(m))
False
>>> prune.random_unstructured(m, name='weight', amount=0.2)
>>> print(prune.is_pruned(m))
True
优云智算