训练
训练深度学习模型可能会变得非常复杂。PyTorch Tabular 通过继承 PyTorch Lightning,将整个工作负载转移到底层的 PyTorch Lightning 框架上。它的设计旨在使模型训练变得轻而易举,同时赋予你灵活性,让你能够定制训练过程。
PyTorch Tabular 中的训练器直接或间接地继承了 PyTorch Lightning 训练器的所有功能。
基本用法¶
你最常设置的参数包括:
batch_size
: int: 每个训练批次中的样本数量。默认为64
max_epochs
: int: 要运行的最大周期数。在启用早停的情况下,这是最大值;在没有早停的情况下,这是将要运行的周期数。默认为10
-
devices
: (Optional[int]): 用于训练的设备数量(整数)。-1 表示使用所有可用设备。默认使用所有可用设备(-1) -
accelerator
: Optional[str]: 用于训练的加速器。可以是 'cpu'、'gpu'、'tpu'、'ipu'、'auto' 之一。默认为 'auto'。 load_best
: int: 标志,用于在训练期间加载保存的最佳模型。如果关闭了检查点保存,则此项将被忽略。默认为 True
使用示例¶
PyTorch Tabular 默认使用早停机制,并监控 valid_loss
以停止训练。检查点保存也默认开启,它会监控 valid_loss
并将最佳模型保存在 saved_models
文件夹中。所有这些都可以在下一节中进行配置。
高级用法¶
早停和检查点保存¶
早停默认开启。但你可以通过将 early_stopping
设置为 None
来关闭它。如果你想监控其他指标,只需在 early_stopping
参数中提供该指标名称。控制早停的其他几个参数包括:
early_stopping_min_delta
: float: 损失/指标中被视为改进的最小增量。默认为0.001
early_stopping_mode
: str: 损失/指标应优化的方向。选项为max
和min
。默认为min
early_stopping_patience
: int: 在没有进一步改进损失/指标的情况下等待的周期数。默认为3
min_epochs
: int: 要运行的最小周期数。无论停止标准如何,都会运行这么多周期。默认为1
检查点保存也默认开启,要关闭它,可以将 checkpoints
参数设置为 None
。如果你想监控其他指标,只需在 early_stopping
参数中提供该指标名称。控制检查点保存的其他几个参数包括:
checkpoints_path
: str: 保存模型的路径。默认为saved_models
checkpoints_mode
: str: 损失/指标应优化的方向。选项为max
和min
。默认为min
checkpoints_save_top_k
: int: 要保存的最佳模型数量。如果你想保存多个最佳模型,可以将此参数设置为 >1。默认为1
Note
确保你要跟踪的指标/损失名称与日志中的名称完全匹配。推荐的方法是运行一个模型并评估结果。从结果字典中,你可以选择一个键来在训练期间跟踪。
学习率查找器¶
首先在这篇论文 Cyclical Learning Rates for Training Neural Networks 中提出,随后被 fast.ai 推广,这是一种无需昂贵搜索即可达到最优学习率附近的技术。PyTorch Tabular 允许你使用论文中提出的方法找到最佳学习率,并自动将其用于训练网络。所有这些都可以通过一个简单的标志 auto_lr_find
开启。
我们还可以使用 [pytorch_tabular.TabularModel.find_learning_rate] 作为一个单独的步骤来运行学习率查找器。
控制梯度/优化¶
在训练过程中,有时你可能需要对梯度优化过程进行更严格的控制。例如,如果梯度爆炸,你可能希望在每次更新前裁剪梯度值。gradient_clip_val
允许你这样做。
有时,你可能希望在执行反向传播之前跨多个批次累积梯度(可能是因为较大的批次大小不适合你的 GPU)。PyTorch Tabular 允许你通过 accumulate_grad_batches
来实现这一点。
调试¶
很多时候,你需要调试模型,看看为什么它没有按预期表现。甚至在开发新模型时,你也需要大量调试模型。PyTorch Lightning 为此提供了一些功能,PyTorch Tabular 也采用了这些功能。 为了找出性能瓶颈,我们可以使用:
profiler
: Optional[str]: 在训练过程中分析各个步骤,以帮助识别瓶颈。可选值为:None
simple
advanced
。默认为None
为了检查整个设置是否无误运行,我们可以使用:
fast_dev_run
: Optional[str]: 快速调试验证运行。默认为False
如果模型学习不正常:
-
overfit_batches
: float: 使用训练集的这部分数据。如果不为零,将使用相同的训练集进行验证和测试。如果训练数据加载器设置了 shuffle=True,Lightning 会自动禁用它。适用于快速调试或有意过拟合。默认为0
-
track_grad_norm
: bool: 仅在设置实验跟踪时使用。在日志记录器中跟踪和记录梯度范数。默认值为 -1 表示不跟踪。1 表示 L1 范数,2 表示 L2 范数,依此类推。默认为False
。如果梯度范数迅速降至零,则存在问题。
使用完整的 PyTorch Lightning Trainer¶
要充分发挥 PyTorch Lightning Trainer 的潜力,可以使用 trainer_kwargs
参数。这将允许你传递 PyTorch Lightning Trainer 支持的任何参数。完整的文档可以在这里找到
pytorch_tabular.config.TrainerConfig
dataclass
¶
训练器配置.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size
|
int
|
每个训练批次中的样本数量 |
64
|
data_aware_init_batch_size
|
int
|
数据感知初始化时每个训练批次中的样本数量, 适用时默认值为2000 |
2000
|
fast_dev_run
|
bool
|
如果设置为 |
False
|
max_epochs
|
int
|
要运行的最大周期数 |
10
|
min_epochs
|
Optional[int]
|
强制训练至少这些周期数.默认值为1 |
1
|
max_time
|
Optional[int]
|
经过此时间后停止训练.默认禁用(None) |
None
|
accelerator
|
Optional[str]
|
用于训练的加速器.可以是以下之一:
'cpu','gpu','tpu','ipu', 'mps', 'auto'.默认为'auto'.
可选值为:[ |
'auto'
|
devices
|
Optional[int]
|
用于训练的设备数量(整数).-1表示使用所有可用设备. 默认情况下,使用所有可用设备(-1) |
-1
|
devices_list
|
Optional[List[int]]
|
用于训练的设备列表(列表).如果指定,
优先于 |
None
|
accumulate_grad_batches
|
int
|
每k个批次或按字典设置累积梯度.训练器 也会在最后一个不可整除的步骤数上调用optimizer.step(). |
1
|
auto_lr_find
|
bool
|
在调用trainer.tune()时运行学习率查找算法, 以找到最佳初始学习率. |
False
|
auto_select_gpus
|
bool
|
如果启用且 |
True
|
check_val_every_n_epoch
|
int
|
每n个训练周期检查一次验证. |
1
|
gradient_clip_val
|
float
|
梯度裁剪值 |
0.0
|
overfit_batches
|
float
|
使用训练集的此部分数据.如果不为零,将使用相同的 训练集进行验证和测试.如果训练数据加载器的shuffle=True,Lightning 将自动禁用它.对于快速调试或故意过拟合很有用. |
0.0
|
deterministic
|
bool
|
如果为真,启用cudnn.deterministic.可能会使系统变慢,但 确保可重复性. |
False
|
profiler
|
Optional[str]
|
在训练期间分析各个步骤并协助识别
瓶颈.可以是None、simple或advanced、pytorch.可选值为:
[ |
None
|
early_stopping
|
Optional[str]
|
需要监控的损失/指标以进行早停.如果 为None,则不会进行早停 |
'valid_loss'
|
early_stopping_min_delta
|
float
|
早停中损失/指标的最小变化量, 符合改进条件 |
0.001
|
early_stopping_mode
|
str
|
损失/指标应优化的方向.可选值为:
[ |
'min'
|
early_stopping_patience
|
int
|
在损失/指标没有进一步改善之前等待的周期数 |
3
|
early_stopping_kwargs
|
Optional[Dict]
|
早停回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning EarlyStopping回调的文档. |
lambda: {}()
|
checkpoints
|
Optional[str]
|
需要监控的损失/指标以进行检查点保存.如果为None, 则不会进行检查点保存 |
'valid_loss'
|
checkpoints_path
|
str
|
保存模型的路径 |
'saved_models'
|
checkpoints_every_n_epochs
|
int
|
检查点之间的训练步数 |
1
|
checkpoints_name
|
Optional[str]
|
保存模型的名称.如果留空,
首先会查找experiment_config中的 |
None
|
checkpoints_mode
|
str
|
损失/指标应优化的方向 |
'min'
|
checkpoints_save_top_k
|
int
|
保存的最佳模型数量 |
1
|
checkpoints_kwargs
|
Optional[Dict]
|
检查点回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning ModelCheckpoint回调的文档. |
lambda: {}()
|
load_best
|
bool
|
标志以加载训练期间保存的最佳模型 |
True
|
track_grad_norm
|
int
|
在日志记录器中跟踪和记录梯度范数.默认值-1表示不跟踪. 1表示L1范数,2表示L2范数,依此类推. |
-1
|
progress_bar
|
str
|
进度条类型.可以是以下之一: |
'rich'
|
precision
|
int
|
模型的精度.可以是以下之一: |
32
|
seed
|
int
|
随机数生成器的种子.默认为42 |
42
|
trainer_kwargs
|
Dict[str, Any]
|
传递给PyTorch Lightning Trainer的额外关键字参数.请参阅 https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer |
dict()
|
Source code in src/pytorch_tabular/config/config.py
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 |
|