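Pruning
High-level API to automatically prune and optimize your model with various algorithms.
Functions
prune: Prune a given model by searching for the best architecture within the design space.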
- prune(model, mode, constraints, dummy_input, config=None)
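Prune a given model by searching for the best architecture within the design space.
- Parameters:
model (Module) – A standard model that contains standard building blocks which can be pruned in-place.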
mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) –
A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its config indicating the desired mode(s) (and configurations) for the convert process. Modes set up the model for different algorithms for model optimization. The following modes are available:
"fastnas": The model will be converted into a search space and set up to automatically perform the operations required for FastNAS pruning & search. The mode's config is described in FastNASConfig. This mode is recommended for pruning Computer Vision models.
"gradnas": The model will be converted into a search space and set up to automatically perform the operations required for gradient-based pruning & search. The mode's config is described in GradNASConfig. This mode is recommended for pruning Hugging Face language models such as BERT and GPT-J.
"mcore_gpt_minitron": The model will be converted into a search space and set up to automatically perform the operations required for Minitron-style pruning & search. The mode's config is described in MCoreGPTMinitronConfig. This mode is required for pruning NVIDIA Megatron-Core / NeMo GPT-type models.
If the mode argument is specified as a dictionary, the keys indicate the modes and the values specify the per-mode configurations. If no configuration is provided, the default configuration is used.
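The accepted forms of the mode argument can be sketched as a small normalization step. normalize_mode is a hypothetical helper written only for illustration (it is not part of ModelOpt, and it handles mode names as strings, not _ModeDescriptor objects); it turns each accepted form into a list of (mode, config) tuples, with an empty dict standing in for the default configuration.

```python
from typing import Any, Dict, List, Tuple


def normalize_mode(mode: Any) -> List[Tuple[str, Dict[str, Any]]]:
    """Return the mode argument as a list of (mode_name, config) tuples.

    Hypothetical helper for illustration only. An empty config dict stands
    for "use the default configuration" of that mode.
    """
    if isinstance(mode, str):
        # A single mode name, e.g. "fastnas"
        return [(mode, {})]
    if isinstance(mode, dict):
        # Dictionary form: keys are modes, values are per-mode configs
        return [(name, dict(cfg or {})) for name, cfg in mode.items()]
    normalized = []
    for item in mode:
        if isinstance(item, str):
            # Bare mode name inside a list
            normalized.append((item, {}))
        else:
            # (mode_name, config) tuple inside a list
            name, cfg = item
            normalized.append((name, dict(cfg or {})))
    return normalized


print(normalize_mode("fastnas"))          # [('fastnas', {})]
print(normalize_mode({"gradnas": None}))  # [('gradnas', {})]
```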
constraints (Dict[str, str | float | Dict | None]) –
A dictionary mapping constraint names to their respective values that the pruned model must satisfy. Currently, the supported constraints are flops, params, and export_config. If the key is flops or params, the value should be an upper bound, given either as an absolute number or as a percentage string of the original model's value (e.g. "60%"). For export_config, the value is a dictionary mapping hyperparameter names to their pruned values. For example:
    # Specify a flops upper bound of 4.5 GFLOPs
    constraints = {"flops": 4.5e9}

    # Specify a percentage-based constraint
    # (e.g., search for a model with <= 60% of the original model params)
    constraints = {"params": "60%"}

    # Specify export_config with pruned hyperparameters
    # This is supported and required if the model is converted via ``mcore_gpt_minitron`` mode.
    constraints = {
        "export_config": {
            "ffn_hidden_size": 128,
            "num_attention_heads": 16,
            "num_query_groups": 4,
        }
    }
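The flops/params rule can be sketched as follows, assuming a percentage string is read relative to the original model's value. resolve_upper_bound and satisfies are hypothetical helpers, not ModelOpt functions, and export_config constraints are not covered.

```python
from typing import Union


def resolve_upper_bound(value: Union[str, float], original: float) -> float:
    """Turn a flops/params constraint value into an absolute upper bound.

    Hypothetical helper: a string such as "60%" is read as a percentage of
    the original model's value; a plain number is used as-is.
    """
    if isinstance(value, str):
        text = value.strip()
        if not text.endswith("%"):
            raise ValueError(f"expected a percentage string like '60%', got {value!r}")
        return float(text[:-1]) / 100.0 * original
    return float(value)


def satisfies(candidate: float, value: Union[str, float], original: float) -> bool:
    """Return True if a candidate's flops/params count is within the bound."""
    return candidate <= resolve_upper_bound(value, original)


original_params = 25e6  # illustrative parameter count of the unpruned model
print(satisfies(14e6, "60%", original_params))  # True: 14M <= 15M
print(satisfies(16e6, "60%", original_params))  # False: 16M > 15M
```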
dummy_input (Any | Tuple) –
Arguments of model.forward(). This is used for exporting and calculating inference-based metrics, such as FLOPs. The format of dummy_input follows the convention of the args argument in torch.onnx.export. Specifically, dummy_input can be:
a single argument (type(dummy_input) != tuple) corresponding to model.forward(dummy_input)
a tuple of arguments corresponding to model.forward(*dummy_input)
a tuple of arguments such that type(dummy_input[-1]) == dict corresponding to model.forward(*dummy_input[:-1], **dummy_input[-1])
Warning
In this last case, the model's forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) or variable keyword arguments (e.g. forward(..., **kwargs)), since these cannot be sorted into positional arguments.
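To check ahead of time whether a forward() signature runs into this restriction, Python's standard inspect module can list the offending parameters. unsortable_params is a hypothetical helper for illustration, not something ModelOpt exposes.

```python
import inspect
from typing import Callable, List


def unsortable_params(forward: Callable) -> List[str]:
    """Return the names of forward() parameters that cannot be passed positionally.

    Keyword-only parameters and a **kwargs catch-all cannot be sorted into
    positional arguments. Hypothetical helper for illustration only.
    """
    blocked_kinds = (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)
    return [
        name
        for name, param in inspect.signature(forward).parameters.items()
        if param.kind in blocked_kinds
    ]


def forward_ok(x, y=None):
    return x


def forward_bad(x, *, mask=None, **kwargs):
    return x


print(unsortable_params(forward_ok))   # []
print(unsortable_params(forward_bad))  # ['mask', 'kwargs']
```

For a bound method such as model.forward, inspect.signature already omits self, so the helper can be called on it directly.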
Note
In order to pass a dict as the last non-keyword argument, you need to use a tuple as dummy_input and add an empty dict as the last element, e.g., dummy_input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See torch.onnx.export for more info.
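The three dummy_input conventions can be sketched as a function that returns the positional and keyword arguments for model.forward(). split_dummy_input is hypothetical, written from the rules in this section rather than from ModelOpt or PyTorch source; a plain list stands in for an input tensor.

```python
from typing import Any, Dict, Tuple


def split_dummy_input(dummy_input: Any) -> Tuple[tuple, Dict[str, Any]]:
    """Map dummy_input to (args, kwargs) for model.forward(*args, **kwargs).

    Hypothetical sketch of the three conventions described for dummy_input.
    """
    if type(dummy_input) is not tuple:
        # Case 1: a single argument -> model.forward(dummy_input)
        return (dummy_input,), {}
    if len(dummy_input) > 0 and type(dummy_input[-1]) is dict:
        # Case 3: trailing dict -> model.forward(*dummy_input[:-1], **dummy_input[-1])
        return dummy_input[:-1], dict(dummy_input[-1])
    # Case 2: plain tuple -> model.forward(*dummy_input)
    return dummy_input, {}


x = [1.0, 2.0]  # stands in for an input tensor
print(split_dummy_input(x))                  # (([1.0, 2.0],), {})
print(split_dummy_input((x, {"y": 3})))      # (([1.0, 2.0],), {'y': 3})
print(split_dummy_input((x, {"y": 3}, {})))  # (([1.0, 2.0], {'y': 3}), {})
```

The last call shows why the trailing empty dict is needed: it is consumed as the (empty) keyword arguments, so the dict before it is passed positionally.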
Note that if you provide a dummy_input with batch size b, the results will be computed based on batch size b.
config (Dict[str, Any] | None) –
Additional optional arguments to configure the search. Currently, we support:
checkpoint: Path to save/restore a checkpoint with a dictionary containing the intermediate search state. If provided, the intermediate search state will be automatically restored before search (if it exists) and stored/saved during search.
verbose: Whether to print detailed search space profiling and search stats during search.
forward_loop: A Callable that takes a model as input and runs a forward loop on it. It is recommended to choose the data loader used inside the forward loop carefully to reduce the runtime. Cannot be provided at the same time as data_loader and collect_func.
data_loader: An iterator yielding batches of data for calibrating the normalization layers in the model or computing gradient scores. It is recommended to use the same data loader as for training but with significantly fewer iterations. Cannot be provided at the same time as forward_loop.
collect_func: A Callable that takes a batch of data from the data loader as input and returns the input to model.forward() as described in run_forward_loop. Cannot be provided at the same time as forward_loop.
max_iter_data_loader: Maximum number of iterations to run the data loader.
score_func: A callable taking the model as input and returning a single accuracy/score metric (float). This metric will be maximized during search.
Note
The score_func is required only for fastnas mode. It will be evaluated on models in eval mode (model.eval()).
loss_func: A Callable which takes the model output (i.e. the output of model.forward()) and the batch of data as its inputs and returns a scalar loss. This is a required argument if the model is converted via gradnas mode.
It should be possible to run a backward pass on the loss value returned by this method.
collect_func will be used to gather the inputs to model.forward() from a batch of data yielded by data_loader. loss_func should support the following usage:
    for i, batch in enumerate(data_loader):
        if i >= max_iter_data_loader:
            break

        # Assuming collect_func returns a tuple of arguments
        output = model(*collect_func(batch))

        loss = loss_func(output, batch)
        loss.backward()
Note
Additional configuration options may be added by individual algorithms. Please refer to the documentation of the individual algorithms for more information.
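The mutual-exclusion and per-mode requirements listed for config can be sketched as a small validator. config_problems is a hypothetical helper; ModelOpt performs its own validation, and its exact checks and error messages may differ.

```python
from typing import Any, Dict, List


def config_problems(mode: str, config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems with a prune() config dict.

    Hypothetical validator based on the rules listed for config; it is not
    ModelOpt's own validation and only checks a few of the stated rules.
    """
    problems = []
    uses_loop = config.get("forward_loop") is not None
    uses_loader = any(config.get(key) is not None for key in ("data_loader", "collect_func"))
    if uses_loop and uses_loader:
        problems.append("forward_loop cannot be combined with data_loader/collect_func")
    if mode == "fastnas" and config.get("score_func") is None:
        problems.append("score_func is required for fastnas mode")
    if mode == "gradnas" and config.get("loss_func") is None:
        problems.append("loss_func is required for gradnas mode")
    return problems


print(config_problems("fastnas", {"score_func": lambda model: 0.0}))  # []
print(config_problems("gradnas", {"forward_loop": lambda model: None, "data_loader": []}))
```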
- Return type:
Tuple[Module, Dict[str, Any]]
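- Returns: A tuple (subnet, state_dict) where
subnet is the searched subnet (nn.Module), which can be used for subsequent tasks such as fine-tuning, and state_dict contains the history and detailed stats of the search procedure.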
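Note
The given model is modified (exported) in-place to match the best subnet found by the search algorithm. The returned subnet is thus a reference to the same model instance as the input model.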
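Because the model is modified in place, the returned subnet and the input model are the same object. The sketch uses a hypothetical toy_prune on a plain dict of hyperparameters (not a real nn.Module, and not ModelOpt's prune) to show the aliasing and the usual remedy of copying first; for a PyTorch module, calling copy.deepcopy(model) before prune() typically serves the same purpose.

```python
import copy
from typing import Any, Dict, Tuple


def toy_prune(model: Dict[str, int], keep: Dict[str, int]) -> Tuple[Dict[str, int], Dict[str, Any]]:
    """Hypothetical stand-in for prune(): modifies `model` in place and returns it."""
    history = {"before": dict(model)}  # snapshot recorded as "search state"
    model.update(keep)  # in-place modification of the caller's object
    return model, history


model = {"ffn_hidden_size": 512, "num_attention_heads": 32}
backup = copy.deepcopy(model)  # keep an untouched copy if the original is still needed
subnet, state = toy_prune(model, {"ffn_hidden_size": 128})

print(subnet is model)            # True: same object, not a copy
print(backup["ffn_hidden_size"])  # 512
```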