Reference for ultralytics/engine/trainer.py
Note
This file is available at https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/trainer.py.
ultralytics.engine.trainer.BaseTrainer
A base class for creating trainers.
Attributes:
Name | Type | Description |
---|---|---|
args | SimpleNamespace | Configuration for the trainer. |
validator | BaseValidator | Validator instance. |
model | Module | Model instance. |
callbacks | defaultdict | Dictionary of callbacks. |
save_dir | Path | Directory to save results. |
wdir | Path | Directory to save weights. |
last | Path | Path to the last checkpoint. |
best | Path | Path to the best checkpoint. |
save_period | int | Save checkpoint every x epochs (disabled if < 1). |
batch_size | int | Batch size for training. |
epochs | int | Number of epochs to train for. |
start_epoch | int | Starting epoch for training. |
device | torch.device | Device to use for training. |
amp | bool | Flag to enable AMP (Automatic Mixed Precision). |
scaler | GradScaler | Gradient scaler for AMP. |
data | str | Path to data. |
trainset | Dataset | Training dataset. |
testset | Dataset | Testing dataset. |
ema | Module | EMA (Exponential Moving Average) of the model. |
resume | bool | Whether to resume training from a checkpoint. |
lf | callable | Learning-rate lambda function used by the scheduler. |
scheduler | _LRScheduler | Learning rate scheduler. |
best_fitness | float | The best fitness value achieved. |
fitness | float | Current fitness value. |
loss | float | Current loss value. |
tloss | float | Total loss value. |
loss_names | list | List of loss names. |
csv | Path | Path to results CSV file. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | str | Path to a configuration file. | DEFAULT_CFG |
overrides | dict | Configuration overrides. | None |
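The `cfg` defaults and the `overrides` dict end up in the `args` namespace listed above. A minimal sketch of such a merge, assuming overrides simply take precedence over the defaults key by key; `DEFAULTS` (a tiny illustrative subset) and `build_args` are hypothetical names, not the Ultralytics API, and any validation of keys or types is omitted.

```python
from types import SimpleNamespace
from typing import Optional

# Hypothetical stand-in for the default configuration (small subset for illustration).
DEFAULTS = {"epochs": 100, "batch": 16, "optimizer": "auto", "lr0": 0.01, "resume": False}


def build_args(defaults: dict, overrides: Optional[dict] = None) -> SimpleNamespace:
    """Merge user overrides on top of the defaults and expose them as attributes."""
    merged = {**defaults, **(overrides or {})}  # keys from overrides win
    return SimpleNamespace(**merged)


args = build_args(DEFAULTS, {"epochs": 3, "batch": 8})
print(args.epochs, args.batch, args.optimizer)  # 3 8 auto
```

Using a `SimpleNamespace` keeps attribute access (`args.epochs`) while leaving the defaults dict untouched.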
add_callback
build_dataset
build_optimizer
build_optimizer(model, name='auto', lr=0.001, momentum=0.9, decay=1e-05, iterations=100000.0)
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, weight decay, and number of iterations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Module | The model for which to build an optimizer. | required |
name | str | The name of the optimizer to use. If 'auto', the optimizer is selected based on the number of iterations. | 'auto' |
lr | float | The learning rate for the optimizer. | 0.001 |
momentum | float | The momentum factor for the optimizer. | 0.9 |
decay | float | The weight decay for the optimizer. | 1e-05 |
iterations | float | The number of iterations, which determines the optimizer if name is 'auto'. | 100000.0 |
Returns:
Type | Description |
---|---|
Optimizer | The constructed optimizer. |
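With `name='auto'`, the choice depends on `iterations`. The sketch below shows one plausible rule. The 10,000-iteration threshold, the fixed SGD settings and the AdamW learning-rate fit scaled by the class count `nc` reflect our reading of recent Ultralytics releases and should be treated as assumptions. `resolve_optimizer` is a hypothetical helper that only picks hyperparameters; it does not construct a `torch.optim` object.

```python
from typing import Tuple


def resolve_optimizer(name: str = "auto", lr: float = 0.001, momentum: float = 0.9,
                      iterations: float = 1e5, nc: int = 10) -> Tuple[str, float, float]:
    """Return (optimizer_name, lr, momentum), resolving name='auto' from the iteration count.

    Assumption: the thresholds and constants mirror recent Ultralytics releases.
    """
    if name != "auto":
        return name, lr, momentum  # explicit choice: keep the caller's hyperparameters
    if iterations > 10000:
        return "SGD", 0.01, 0.9  # long schedules: SGD with a fixed learning rate
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # short schedules: AdamW, lr fitted to class count
    return "AdamW", lr_fit, 0.9
```

For example, `resolve_optimizer(iterations=500, nc=10)` selects AdamW with a learning rate of about 0.000714, while the default `iterations=1e5` selects SGD.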
build_targets
check_resume
Check if resume checkpoint exists and update arguments accordingly.
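A sketch of the existence check, assuming `resume` is either a boolean or a checkpoint path. `resolve_resume` and the rule "pick the most recently modified `last.pt` under a runs directory" are hypothetical; updating the training arguments from the checkpoint is not shown.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_resume(resume: Union[bool, str, Path], runs_dir: Path) -> Optional[Path]:
    """Return the checkpoint to resume from, or None if resuming is disabled.

    Hypothetical rule: a path is used as given, while resume=True picks the most
    recently modified 'last.pt' found anywhere under runs_dir.
    """
    if not resume:
        return None
    if isinstance(resume, (str, Path)):
        ckpt = Path(resume)
    else:
        candidates = sorted(Path(runs_dir).rglob("last.pt"), key=lambda p: p.stat().st_mtime)
        if not candidates:
            raise FileNotFoundError(f"No 'last.pt' found under {runs_dir}")
        ckpt = candidates[-1]  # newest checkpoint
    if not ckpt.exists():
        raise FileNotFoundError(f"Resume checkpoint {ckpt} does not exist")
    return ckpt
```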
final_eval
Performs final evaluation and validation for the object detection YOLO model.
get_dataloader
Returns a dataloader derived from torch.utils.data.DataLoader.
get_dataset
Get the train and val paths from the data dict if it exists.
Returns None if the data format is not recognized.
get_model
Get model and raise NotImplementedError for loading cfg files.
get_validator
Raises a NotImplementedError when called; subclasses must override this method to return a validator.
label_loss_items
Returns a loss dict that labels each item of the training loss tensor.
Note
This is not needed for classification but is necessary for segmentation and detection.
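For detection-style trainers, the labeling can be pictured as zipping prefixed loss names with the loss values. A plain-Python sketch, with a list of floats standing in for the loss tensor; the `prefix/name` key format, the rounding to 5 decimals and the loss names are assumptions chosen for illustration.

```python
from typing import Dict, List, Optional, Sequence, Union


def label_loss_items(
    loss_names: Sequence[str],
    loss_items: Optional[Sequence[float]] = None,
    prefix: str = "train",
) -> Union[Dict[str, float], List[str]]:
    """Pair each loss value with a prefixed name; return only the keys if no values are given."""
    keys = [f"{prefix}/{name}" for name in loss_names]
    if loss_items is None:
        return keys  # e.g. useful for building column headers
    return dict(zip(keys, (round(float(x), 5) for x in loss_items)))


print(label_loss_items(["box_loss", "cls_loss", "dfl_loss"], [1.2345678, 0.5, 0.25]))
# {'train/box_loss': 1.23457, 'train/cls_loss': 0.5, 'train/dfl_loss': 0.25}
```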
on_plot
optimizer_step
Perform a single step of the training optimizer with gradient clipping and EMA update.
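The three parts of the step can be illustrated without PyTorch on plain lists of floats: clip the gradients by their global L2 norm, apply the update, then refresh the EMA copy. This is a sketch. The plain SGD update stands in for whatever optimizer is configured, AMP gradient unscaling is omitted, and the `max_norm=10.0`, `decay=0.9999` and `tau=2000` values are assumptions based on our reading of recent Ultralytics code.

```python
import math
from typing import List, Tuple


def clip_grad_norm(grads: List[float], max_norm: float) -> Tuple[List[float], float]:
    """Scale gradients so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + 1e-6))  # never scale gradients up
    return [g * coef for g in grads], total_norm


def ema_update(ema: List[float], params: List[float], updates: int,
               decay: float = 0.9999, tau: float = 2000.0) -> List[float]:
    """Exponential moving average whose effective decay ramps up with the update count."""
    d = decay * (1 - math.exp(-updates / tau))  # small early on, approaches `decay` later
    return [d * e + (1 - d) * p for e, p in zip(ema, params)]


def optimizer_step(params: List[float], grads: List[float], ema: List[float],
                   updates: int, lr: float = 0.01, max_norm: float = 10.0):
    """One plain SGD step with gradient clipping, followed by an EMA update."""
    clipped, _ = clip_grad_norm(grads, max_norm)
    params = [p - lr * g for p, g in zip(params, clipped)]
    updates += 1
    return params, ema_update(ema, params, updates), updates
```

The ramped decay lets the EMA follow the raw weights closely during the first updates instead of staying anchored to the initial weights.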
plot_metrics
plot_training_labels
plot_training_samples
preprocess_batch
progress_string
read_results_csv
resume_training
Resume YOLO training from the given epoch and best fitness.
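Restoring the counters can be sketched as follows, assuming the checkpoint is a dict that records the last completed epoch (0-based) under `epoch` and the best fitness so far under `best_fitness`. `resume_state` is hypothetical; restoring the model, optimizer and EMA weights is left out.

```python
from typing import Tuple


def resume_state(ckpt: dict, epochs: int) -> Tuple[int, float]:
    """Return (start_epoch, best_fitness) restored from a checkpoint dict."""
    start_epoch = ckpt["epoch"] + 1  # continue with the epoch after the last completed one
    if start_epoch >= epochs:
        raise ValueError(f"Training to {epochs} epochs is already finished; nothing to resume.")
    best_fitness = ckpt.get("best_fitness") or 0.0  # treat a missing/None value as 0.0
    return start_epoch, best_fitness
```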
run_callbacks
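Since `callbacks` is a `defaultdict` keyed by event name, `add_callback`, `set_callback` and `run_callbacks` can be pictured as operations on lists of functions. A self-contained sketch: `CallbackRegistry` is a hypothetical stand-in for the trainer, the append-versus-replace semantics and the convention that each callback receives the trainer are assumptions, and `on_train_epoch_end` is used only as an example event name.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class CallbackRegistry:
    """Hypothetical minimal registry mirroring add_callback / set_callback / run_callbacks."""

    def __init__(self) -> None:
        self.callbacks: DefaultDict[str, List[Callable]] = defaultdict(list)

    def add_callback(self, event: str, callback: Callable) -> None:
        self.callbacks[event].append(callback)  # keep existing callbacks

    def set_callback(self, event: str, callback: Callable) -> None:
        self.callbacks[event] = [callback]  # replace existing callbacks

    def run_callbacks(self, event: str) -> None:
        for callback in self.callbacks.get(event, []):  # unknown events are a no-op
            callback(self)  # each callback receives the trainer-like object


log = []
reg = CallbackRegistry()
reg.add_callback("on_train_epoch_end", lambda t: log.append("a"))
reg.add_callback("on_train_epoch_end", lambda t: log.append("b"))
reg.run_callbacks("on_train_epoch_end")
reg.set_callback("on_train_epoch_end", lambda t: log.append("c"))
reg.run_callbacks("on_train_epoch_end")
print(log)  # ['a', 'b', 'c']
```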
save_metrics
Saves training metrics to a CSV file.
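A sketch of an append-only metrics CSV using the standard `csv` module. The column layout (an `epoch` column followed by one column per metric key, values formatted with `.6g`) and both helper names are assumptions; `read_metrics` is a hypothetical reader for the same file.

```python
import csv
from pathlib import Path
from typing import Dict, List


def save_metrics(csv_path: Path, epoch: int, metrics: Dict[str, float]) -> None:
    """Append one row per epoch; write the header only when the file is first created."""
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["epoch", *metrics.keys()])
        writer.writerow([epoch, *(f"{v:.6g}" for v in metrics.values())])


def read_metrics(csv_path: Path) -> Dict[str, List[str]]:
    """Read the CSV back into a column-name -> list-of-strings dict."""
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        columns: Dict[str, List[str]] = {name: [] for name in reader.fieldnames or []}
        for row in reader:
            for name in columns:
                columns[name].append(row[name])
    return columns
```

Appending one row per epoch keeps earlier results on disk even if training is interrupted.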
save_model
Save model training checkpoints with additional metadata.
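Which files a save writes follows from the `last`, `best` and `save_period` attributes described above. The sketch assumes `best.pt` is written when the current fitness equals the best fitness (i.e. this epoch just set it) and that periodic snapshots are named `epoch{N}.pt`; both rules and the helper name are assumptions. Serializing the checkpoint itself (model, EMA, optimizer state and metadata) is omitted.

```python
from pathlib import Path
from typing import List


def checkpoint_targets(wdir: Path, epoch: int, fitness: float, best_fitness: float,
                       save_period: int) -> List[Path]:
    """Return the checkpoint files a save at this epoch would write (hypothetical naming)."""
    targets = [wdir / "last.pt"]  # always refresh the latest checkpoint
    if fitness == best_fitness:
        targets.append(wdir / "best.pt")  # this epoch achieved the best fitness so far
    if save_period > 0 and epoch % save_period == 0:
        targets.append(wdir / f"epoch{epoch}.pt")  # periodic snapshot; disabled if save_period < 1
    return targets
```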
set_callback
set_model_attributes
setup_model
Load/create/download model for any task.
train
Start training. On multi-GPU systems, passing device='' or device=None defaults to device=0.
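The device rule can be made concrete by computing how many GPU processes a `device` value implies. A sketch under the assumption that a comma-separated string or a list selects several GPUs, while `''` or `None` fall back to a single GPU when CUDA is available; `resolve_world_size` and its `cuda_available` flag are hypothetical, and the distributed launch itself is not shown.

```python
from typing import Sequence, Union


def resolve_world_size(device: Union[str, int, Sequence[int], None], cuda_available: bool) -> int:
    """Number of GPU processes implied by a device argument (hypothetical helper)."""
    if isinstance(device, str) and device:  # e.g. '0' or '0,1,2,3'
        return len(device.split(","))
    if isinstance(device, (list, tuple)):  # e.g. [0, 1]
        return len(device)
    if cuda_available:  # device=None, '' or a single index: one GPU (device 0)
        return 1
    return 0  # no CUDA available: CPU training
```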
validate
Runs validation on the test set using self.validator.
The returned dict is expected to contain a "fitness" key.
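How the "fitness" key can feed `best_fitness` is sketched below; here `metrics` stands in for the dict returned by `self.validator`. Falling back to the negative loss when no "fitness" key is present is an assumption, as is the helper name `update_fitness`.

```python
from typing import Dict, Optional, Tuple


def update_fitness(metrics: Dict[str, float], loss: float,
                   best_fitness: Optional[float]) -> Tuple[Dict[str, float], float, float]:
    """Pop 'fitness' from the validator metrics and track the best value seen so far."""
    metrics = dict(metrics)  # avoid mutating the caller's dict
    fitness = metrics.pop("fitness", -loss)  # assumed fallback: lower loss means higher fitness
    if best_fitness is None or fitness > best_fitness:
        best_fitness = fitness
    return metrics, fitness, best_fitness
```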