AmpOptimWrapper¶
- class mmengine.optim.AmpOptimWrapper(loss_scale='dynamic', dtype=None, use_fsdp=False, **kwargs)[source]¶
A subclass of OptimWrapper that supports automatic mixed precision training based on torch.cuda.amp. AmpOptimWrapper provides a unified interface with OptimWrapper, so AmpOptimWrapper can be used in the same way as OptimWrapper.
Warning
AmpOptimWrapper requires PyTorch >= 1.6.
- Parameters:
loss_scale (float or str or dict) – The initial configuration of torch.cuda.amp.GradScaler. See the specific argument introduction in the PyTorch AMP documentation. Defaults to "dynamic".
"dynamic": Initialize GradScaler without any arguments.
float: Initialize GradScaler with init_scale.
dict: Initialize GradScaler with a more detailed configuration.
dtype (str or torch.dtype, optional) – The data type to autocast in amp. If a str is given, it will be converted to torch.dtype. Valid str formats are 'float16', 'bfloat16', 'float32' and 'float64'. If set to None, the default data type will be used. Defaults to None. New in version 0.6.1.
use_fsdp (bool) – Use ShardedGradScaler when it is True. It should be enabled when using FullyShardedDataParallel. Defaults to False. New in version 0.8.0.
**kwargs – Keyword arguments passed to OptimWrapper.
Warning
The dtype argument is only available with PyTorch >= 1.10.0. If you use an older version of PyTorch, it will be ignored.
Note
If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.
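Example (a minimal sketch: it assumes a CUDA-capable device and a toy linear model, and uses update_params, which is inherited from OptimWrapper):
>>> import torch
>>> import torch.nn as nn
>>> from mmengine.optim import AmpOptimWrapper
>>> model = nn.Linear(4, 2).cuda()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> # loss_scale accepts 'dynamic', a float init_scale, or a dict of
>>> # torch.cuda.amp.GradScaler keyword arguments.
>>> optim_wrapper = AmpOptimWrapper(
...     optimizer=optimizer,
...     loss_scale='dynamic',
...     dtype='float16')  # dtype requires PyTorch >= 1.10.0
>>> inputs = torch.randn(8, 4).cuda()
>>> targets = torch.randn(8, 2).cuda()
>>> # optim_context enables autocast; update_params runs backward,
>>> # step and zero_grad in one call.
>>> with optim_wrapper.optim_context(model):
...     loss = nn.functional.mse_loss(model(inputs), targets)
>>> optim_wrapper.update_params(loss)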
- backward(loss, **kwargs)[source]¶
Perform gradient backpropagation with loss_scaler.
- Parameters:
loss (torch.Tensor) – The loss of the current iteration.
kwargs – Keyword arguments passed to torch.Tensor.backward().
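Example (continuing the sketch above; the extra retain_graph argument is simply forwarded to torch.Tensor.backward()):
>>> loss = nn.functional.mse_loss(model(inputs), targets)
>>> # The loss is scaled by the internal GradScaler before backpropagation.
>>> optim_wrapper.backward(loss, retain_graph=True)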
- load_state_dict(state_dict)[source]¶
Load and parse the state dictionary of optimizer and loss_scaler.
If state_dict contains the "loss_scaler" key, loss_scaler will load the corresponding entry. Otherwise, only the optimizer will load the state dictionary.
- Parameters:
state_dict (dict) – The state dict of optimizer and loss_scaler.
- optim_context(model)[source]¶
Enable the context for mixed precision training, and the context for disabling gradient synchronization during gradient accumulation.
- Parameters:
model (nn.Module) – The training model.
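Example (continuing the sketch above with gradient accumulation; accumulative_counts is forwarded to OptimWrapper through **kwargs, and with a DistributedDataParallel model the context would also skip gradient all-reduce on non-update iterations):
>>> accum_wrapper = AmpOptimWrapper(
...     optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
...     accumulative_counts=2)
>>> for _ in range(2):
...     with accum_wrapper.optim_context(model):
...         loss = nn.functional.mse_loss(model(inputs), targets)
...     accum_wrapper.update_params(loss)  # the optimizer only steps on the second call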
- state_dict()[source]¶
Get the state dictionary of optimizer and loss_scaler.
Based on the state dictionary of the optimizer, the returned state dictionary adds a key named "loss_scaler".
- Returns:
The merged state dict of loss_scaler and optimizer.
- Return type:
dict
- step(**kwargs)[source]¶
Update parameters with loss_scaler.
- Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
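Example (continuing the sketch above, driving the wrapper manually instead of calling update_params):
>>> loss = nn.functional.mse_loss(model(inputs), targets)
>>> optim_wrapper.backward(loss)   # scales the loss before backpropagation
>>> optim_wrapper.step()           # GradScaler.step (unscales internally) + GradScaler.update
>>> optim_wrapper.zero_grad()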