Skip to content

高级自定义

Ultralytics YOLO的命令行和Python接口都只是基础引擎执行器的高级抽象。让我们来看看Trainer引擎。



观看: 掌握Ultralytics YOLO: 高级定制

BaseTrainer

BaseTrainer包含通用的训练样板程序。它可以基于覆盖所需函数或操作进行任何任务的定制,只要遵循正确的格式即可。例如,您可以通过覆盖这些函数来支持您自己的自定义模型和数据加载器:

  • get_model(cfg, weights) - 构建要训练的模型的函数
  • get_dataloader() - 构建数据加载器的函数 更多详情和源代码可以在BaseTrainer参考中找到

DetectionTrainer

以下是如何使用YOLO11 DetectionTrainer并对其进行定制的方法。

from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # 获取最佳模型

定制DetectionTrainer

让我们定制训练器以训练一个不直接支持的自定义检测模型。您可以通过简单地重载现有的get_model功能来实现这一点:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """加载给定配置和权重文件的自定义检测模型。"""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()

您现在意识到需要进一步定制训练器以:

  • 自定义损失函数
  • 添加回调,在每10个epoch后将模型上传到您的Google Drive。以下是如何做到这一点:
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import DetectionModel


class MyCustomModel(DetectionModel):
    def init_criterion(self):
        """初始化损失函数,并添加每10个epoch后将模型上传到Google Drive的回调。"""
        ...


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """返回配置了指定配置和权重的自定义检测模型实例。"""
        return MyCustomModel(...)


# 回调以上传模型权重
def log_model(trainer):
    """记录训练器使用的最后一个模型权重的路径。"""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # 添加到现有回调
trainer.train()

要了解更多关于回调触发事件和入口点的信息,请查看我们的回调指南

其他引擎组件

其他组件如验证器预测器也可以类似地进行定制。有关这些的更多信息,请参阅参考部分。

常见问题

如何为特定任务定制Ultralytics YOLO11 DetectionTrainer?

要为特定任务定制Ultralytics YOLO11 DetectionTrainer,您可以覆盖其方法以适应您的自定义模型和数据加载器。首先从DetectionTrainer继承,然后重新定义get_model等方法以实现您的自定义功能。以下是一个示例:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """加载给定配置和权重文件的自定义检测模型。"""
        ...


trainer = CustomTrainer(overrides={...})
trainer.train()
trained_model = trainer.best  # 获取最佳模型

对于进一步的定制,如更改损失函数或添加回调,您可以参考我们的回调指南

Ultralytics YOLO11中的BaseTrainer的关键组件是什么?

Ultralytics YOLO11中的BaseTrainer作为训练程序的基础,可以通过覆盖其通用方法进行各种任务的定制。关键组件包括:

  • get_model(cfg, weights) 构建要训练的模型。
  • get_dataloader() 构建数据加载器。 有关自定义和源代码的更多详细信息,请参阅 BaseTrainer 参考

如何为 Ultralytics YOLO11 DetectionTrainer 添加回调?

您可以在 Ultralytics YOLO11 DetectionTrainer 中添加回调以监控和修改训练过程。例如,以下是如何添加一个回调以在每次训练 epoch 后记录模型权重的示例:

from ultralytics.models.yolo.detect import DetectionTrainer


# 回调以上传模型权重
def log_model(trainer):
    """记录训练器使用的最后一个模型权重的路径。"""
    last_weight_path = trainer.last
    print(last_weight_path)


trainer = DetectionTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model)  # 添加到现有回调中
trainer.train()

有关回调事件和入口点的更多详细信息,请参阅我们的 回调指南

为什么我应该使用 Ultralytics YOLO11 进行模型训练?

Ultralytics YOLO11 在强大的引擎执行器上提供了高层次的抽象,非常适合快速开发和定制。主要优势包括:

  • 易用性:命令行和 Python 接口简化了复杂任务。
  • 性能:针对实时 目标检测 和各种视觉 AI 应用进行了优化。
  • 定制化:易于扩展以支持自定义模型、损失函数 和数据加载器。

通过访问 Ultralytics YOLO 了解更多关于 YOLO11 的功能。

我可以将 Ultralytics YOLO11 DetectionTrainer 用于非标准模型吗?

是的,Ultralytics YOLO11 DetectionTrainer 非常灵活,可以定制用于非标准模型。通过继承 DetectionTrainer,您可以重载不同的方法以支持您的特定模型需求。以下是一个简单的示例:

from ultralytics.models.yolo.detect import DetectionTrainer


class CustomDetectionTrainer(DetectionTrainer):
    def get_model(self, cfg, weights):
        """加载自定义检测模型。"""
        ...


trainer = CustomDetectionTrainer(overrides={...})
trainer.train()

有关更全面的说明和示例,请查看 DetectionTrainer 文档。


📅 Created 11 months ago ✏️ Updated 13 days ago

Comments