PyTorch 模型更新
pytorch_model_update.py 示例演示了如何使用OutputModel 类训练模型并记录它。
该示例执行以下操作:
- 在
examples
项目中创建一个名为Model update pytorch
的任务。 - 在CIFAR10数据集上训练神经网络以进行图像分类。
- 使用OutputModel对象记录模型、其标签枚举和配置字典。
Disabling automatic framework logging
此示例禁用了PyTorch输出的默认自动捕获功能,以演示如何手动控制从PyTorch记录的内容。有关更多信息,请参阅此FAQ。
初始化
为任务实例化了一个OutputModel对象。
from clearml import Task, OutputModel
task = Task.init(
project_name="examples",
task_name="Model update pytorch",
auto_connect_frameworks={"pytorch": False}
)
output_model = OutputModel(task=task)
标签枚举
标签枚举字典通过Task.connect_label_enumeration
方法记录,该方法将更新任务的结果模型信息。当前运行的任务通过Task.current_task
类方法访问。
# store the label enumeration of the training model
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck",)
enumeration = {k: v for v, k in enumerate(classes, 1)}
Task.current_task().connect_label_enumeration(enumeration)
Directly Setting Model Enumeration
你可以直接使用OutputModel.update_labels
方法来设置模型的标签枚举。
模型配置
使用OutputModel.update_design
方法向模型添加配置字典。
model_config_dict = {
"list_of_ints": [1, 2, 3, 4],
"dict": {
"sub_value": "string",
"sub_integer": 11
},
"value": 13.37
}
model.update_design(config_dict=model_config_dict)
更新模型
要更新模型,请使用OutputModel.update_weights()
。
这将模型上传到设置的存储目的地(参见设置上传目的地),
并将该位置注册为任务的输出模型。
# CONDITION depicts a custom condition for when to save the model. The model is saved and then updated in ClearML
CONDITION = True
if CONDITION:
torch.save(net.state_dict(), PATH)
model.update_weights(weights_filename=PATH)
WebApp
模型出现在任务的ARTIFACTS标签中。
点击模型名称将带您进入模型的页面,在那里您可以查看模型的详细信息并访问模型。
模型的NETWORK选项卡显示其配置。
模型的LABELS标签显示其标签枚举。