OptimWrapper
- class mmengine.optim.OptimWrapper(optimizer, accumulative_counts=1, clip_grad=None)[source]
Optimizer wrapper provides a common interface for updating parameters.
Optimizer wrapper provides a unified interface for single precision training and automatic mixed precision training with different hardware. OptimWrapper encapsulates the optimizer to provide simplified interfaces for commonly used training techniques such as gradient accumulation and gradient clipping.
OptimWrapper implements the basic logic of gradient accumulation and gradient clipping based on torch.optim.Optimizer. Subclasses only need to override a few methods to implement mixed precision training. See AmpOptimWrapper for more information.
- Parameters:
optimizer (Optimizer) – Optimizer used to update model parameters.
accumulative_counts (int) – The number of iterations over which gradients are accumulated. The parameters are updated every accumulative_counts iterations.
clip_grad (dict, optional) – If clip_grad is not None, it is used as the arguments of torch.nn.utils.clip_grad_norm_() or torch.nn.utils.clip_grad_value_() (see the sketch after this list). clip_grad should be a dict, and its keys can be set as follows:
If the key type is not set, or type is "norm", the accepted keys are:
  max_norm (float or int): Max norm of the gradients.
  norm_type (float or int): Type of the used p-norm. Can be 'inf' for infinity norm.
  error_if_nonfinite (bool): If True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf. Defaults to False (will switch to True in the future).
If the key type is set to "value", the accepted keys are:
  clip_value (float or int): Maximum allowed value of the gradients. The gradients are clipped in the range (-clip_value, +clip_value).
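As a rough illustration of how the clip_grad dict maps onto the underlying PyTorch utilities, the sketch below constructs OptimWrapper directly in Python; the model, layer sizes, and learning rate are arbitrary placeholders.

>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(4, 2)  # placeholder model
>>> # Clip by norm: the dict (minus 'type') is forwarded to
>>> # torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, norm_type=2).
>>> wrapper_norm = OptimWrapper(
>>>     optimizer=SGD(model.parameters(), lr=0.01),
>>>     clip_grad=dict(max_norm=1.0, norm_type=2))
>>> # Clip by value: the dict (minus 'type') is forwarded to
>>> # torch.nn.utils.clip_grad_value_(params, clip_value=0.2).
>>> wrapper_value = OptimWrapper(
>>>     optimizer=SGD(model.parameters(), lr=0.01),
>>>     clip_grad=dict(type='value', clip_value=0.2))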
Note

If accumulative_counts is larger than 1, calling update_params() under the optim_context context can avoid unnecessary gradient synchronization.

Note

If you use IterBasedRunner and enable gradient accumulation, the original max_iters should be multiplied by accumulative_counts.

Note

The subclass should ensure that once update_params() is called, _inner_count += 1 is performed automatically.

Example
>>> # Config sample of OptimWrapper and enable clipping gradient by
>>> # norm.
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=1,
>>>     clip_grad=dict(max_norm=0.2))
>>> # Config sample of OptimWrapper and enable clipping gradient by
>>> # value.
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=1,
>>>     clip_grad=dict(type='value', clip_value=0.2))
>>> # Use OptimWrapper to update model.
>>> import torch.nn as nn
>>> import torch
>>> from torch.optim import SGD
>>> from torch.utils.data import DataLoader
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> dataset = torch.randn(10, 1, 1)
>>> dataloader = DataLoader(dataset)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer)
>>>
>>> for data in dataloader:
>>>     loss = model(data).mean()
>>>     optim_wrapper.update_params(loss)
>>> # Enable gradient accumulation.
>>> optim_wrapper_cfg = dict(
>>>     type='OptimWrapper',
>>>     accumulative_counts=3,
>>>     clip_grad=dict(max_norm=0.2))
>>> from torch.nn.parallel import DistributedDataParallel
>>> ddp_model = DistributedDataParallel(model)
>>> optimizer = SGD(ddp_model.parameters(), lr=0.1)
>>> optim_wrapper = OptimWrapper(optimizer, accumulative_counts=3)
>>> optim_wrapper.initialize_count_status(ddp_model, 0, len(dataloader))
>>> # If model is a subclass instance of DistributedDataParallel,
>>> # the `optim_context` context manager can avoid unnecessary gradient
>>> # synchronization.
>>> for iter, data in enumerate(dataloader):
>>>     with optim_wrapper.optim_context(ddp_model):
>>>         loss = ddp_model(data).mean()
>>>     optim_wrapper.update_params(loss)
- backward(loss, **kwargs)[source]
Perform gradient back propagation.
Provides a unified backward interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the backward process.

Note

If a subclass of OptimWrapper overrides backward, _inner_count += 1 must be implemented.
- Parameters:
loss (torch.Tensor) – The loss of the current iteration.
kwargs – Keyword arguments passed to torch.Tensor.backward().
- Return type:
None
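To make the note above concrete, here is a minimal, hypothetical sketch of a subclass that overrides backward; only the _inner_count bookkeeping mirrors the documented requirement, and any extra work (e.g. loss scaling with a GradScaler) would go where the comment indicates.

>>> import torch
>>> from mmengine.optim import OptimWrapper
>>>
>>> class MyOptimWrapper(OptimWrapper):
>>>     def backward(self, loss: torch.Tensor, **kwargs) -> None:
>>>         # Hypothetical extra work (e.g. scaling) would happen here.
>>>         loss.backward(**kwargs)
>>>         # Required bookkeeping when overriding `backward`.
>>>         self._inner_count += 1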
- initialize_count_status(model, init_counts, max_counts)[source]
Initialize gradient accumulation related attributes.
OptimWrapper can be used without calling initialize_count_status. However, consider the case of len(dataloader) == 10 and accumulative_counts == 3. Since 10 is not divisible by 3, the last iteration would not trigger optimizer.step(), resulting in one less parameter update. Initializing the count status lets the wrapper detect the final iteration and still update the parameters there.
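A sketch of the scenario described above, assuming a dataloader of length 10 and accumulative_counts=3; the model, data, and learning rate are placeholders.

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from torch.utils.data import DataLoader
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> dataloader = DataLoader(torch.randn(10, 1, 1))
>>> optim_wrapper = OptimWrapper(
>>>     SGD(model.parameters(), lr=0.1), accumulative_counts=3)
>>> # Tell the wrapper the starting iteration and the total number of
>>> # iterations, so the incomplete last group (iteration 10) still
>>> # triggers an update and the loss is scaled accordingly.
>>> optim_wrapper.initialize_count_status(model, 0, len(dataloader))
>>> for data in dataloader:
>>>     loss = model(data).mean()
>>>     optim_wrapper.update_params(loss)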
- property inner_count
Get the number of times update_params() has been called, i.e., the optimizer wrapper's internal iteration count (_inner_count).
- optim_context(model)[source]
A context manager for gradient accumulation and automatic mixed precision training.
If subclasses need to enable the context for mixed precision training, e.g., AmpOptimWrapper, the corresponding context should be enabled in optim_context. Since OptimWrapper uses plain fp32 training, optim_context only enables the context that blocks unnecessary gradient synchronization during gradient accumulation.

If the model has a no_sync method (i.e., it can block gradient synchronization) and self._accumulative_counts != 1, gradient synchronization is blocked at the iterations where the parameters are not updated. Otherwise, this method enables an empty context.
- Parameters:
model (nn.Module) – The training model.
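The blocking behaviour can be pictured roughly as follows. This is an illustrative sketch of the idea, not the actual implementation, assuming a model (e.g. DistributedDataParallel) that exposes no_sync.

>>> from contextlib import contextmanager, nullcontext
>>>
>>> @contextmanager
>>> def sketch_optim_context(optim_wrapper, model):
>>>     # Block the automatic all-reduce at iterations where the
>>>     # parameters will not be updated; otherwise do nothing special.
>>>     if hasattr(model, 'no_sync') and not optim_wrapper.should_sync():
>>>         ctx = model.no_sync()
>>>     else:
>>>         ctx = nullcontext()
>>>     with ctx:
>>>         yield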
- scale_loss(loss)[source]
Get the scaled loss according to _accumulative_counts, _inner_count and max_counts.
- Parameters:
loss (torch.Tensor) – Original loss calculated by the model.
- Returns:
Scaled loss.
- Return type:
torch.Tensor
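A minimal sketch of the expected effect, assuming accumulative_counts=4: each per-iteration loss is divided by the accumulation factor so that the accumulated gradients match a single large-batch step. The exact handling of a trailing, incomplete accumulation group depends on _inner_count and max_counts.

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> optim_wrapper = OptimWrapper(
>>>     SGD(model.parameters(), lr=0.1), accumulative_counts=4)
>>> loss = model(torch.randn(2, 1)).mean()
>>> scaled = optim_wrapper.scale_loss(loss)
>>> # Expected: scaled == loss / 4, so four accumulated backward passes
>>> # contribute the same overall gradient as one large-batch step.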
- should_sync()[source]
Decide whether the automatic gradient synchronization should be allowed at the current iteration.
It takes effect when gradient accumulation is used, skipping synchronization at the iterations where the parameters are not updated.
Since should_sync is called by optim_context() before backward(), self._inner_count += 1 has not happened yet at that point. Therefore, the increment has to be accounted for manually inside this check.
- Returns:
Whether to allow the automatic gradient synchronization.
- Return type:
bool
- should_update()[source]
Decide whether the parameters should be updated at the current iteration.
Called by update_params() to check whether the optimizer wrapper should update parameters at the current iteration.
- Returns:
Whether to update parameters.
- Return type:
bool
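To illustrate how these checks interact under gradient accumulation, the following sketch reproduces the schedule implied by the documentation for accumulative_counts=3 over 10 iterations. It mirrors the documented behaviour rather than calling any private API; whether the trailing iteration 10 triggers an update depends on initialize_count_status having been called with max_counts.

>>> accumulative_counts, max_counts = 3, 10
>>> updates = [
>>>     cur_iter for cur_iter in range(1, max_counts + 1)
>>>     if cur_iter % accumulative_counts == 0 or cur_iter == max_counts
>>> ]
>>> # updates == [3, 6, 9, 10]: gradients are synchronized and the
>>> # parameters stepped at these iterations; all others only accumulate.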
- step(**kwargs)[source]
A wrapper of Optimizer.step.

Provides a unified step interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic. For example, torch.cuda.amp requires some extra operations on GradScaler during the step process.

Clips gradients if clip_grad_kwargs is not None, and then updates parameters.
- Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.step().
- Return type:
None
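For cases where update_params() is too coarse, the documented lower-level methods can be called individually; a minimal sketch with a placeholder model and data.

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
>>> loss = model(torch.randn(2, 1)).mean()
>>> optim_wrapper.backward(loss)  # back-propagate and count the iteration
>>> optim_wrapper.step()          # clip gradients (if configured) and step
>>> optim_wrapper.zero_grad()     # clear gradients for the next iteration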
- update_params(loss, step_kwargs=None, zero_kwargs=None)[source]
Update parameters in optimizer.
- Parameters:
loss (torch.Tensor) – A tensor for back propagation.
step_kwargs (dict) – Arguments for optimizer.step. Defaults to None. New in version v0.4.0.
zero_kwargs (dict) – Arguments for optimizer.zero_grad. Defaults to None. New in version v0.4.0.
- Return type:
None
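A short sketch of forwarding extra arguments through update_params; set_to_none is a standard torch.optim.Optimizer.zero_grad argument, used here purely as an example (requires mmengine >= v0.4.0 per the parameter notes above).

>>> import torch
>>> import torch.nn as nn
>>> from torch.optim import SGD
>>> from mmengine.optim import OptimWrapper
>>>
>>> model = nn.Linear(1, 1)
>>> optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.1))
>>> loss = model(torch.randn(2, 1)).mean()
>>> # zero_kwargs is forwarded to optimizer.zero_grad after the step.
>>> optim_wrapper.update_params(loss, zero_kwargs=dict(set_to_none=True))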
- zero_grad(**kwargs)[source]
A wrapper of Optimizer.zero_grad.

Provides a unified zero_grad interface compatible with automatic mixed precision training. Subclasses can override this method to implement the required logic.
- Parameters:
kwargs – Keyword arguments passed to torch.optim.Optimizer.zero_grad().
- Return type:
None