进度与日志记录

! [ -e /content ] && pip install -Uqq fastai  # 在Colab上升级fastai
from __future__ import annotations
from fastai.basics import *
from nbdev.showdoc import *

回调和辅助函数用于跟踪训练进度或记录结果

from fastai.test_utils import *

进度回调 -

@docs
class ProgressCallback(Callback):
    "A `Callback` that shows a master bar over the epochs and a child bar over the batches"
    order = 60
    _stateattrs = ('mbar','pbar')

    def before_fit(self):
        # The metric names written below come from the recorder, so it must be attached.
        assert hasattr(self.learn, 'recorder')
        if self.create_mbar:
            self.mbar = master_bar(list(range(self.n_epoch)))
        if self.learn.logger != noop:
            # Route logging through the master bar, remembering the previous
            # logger so that `after_fit` can put it back.
            self.old_logger = self.logger
            self.learn.logger = self._write_stats
            self._write_stats(self.recorder.metric_names)
        else:
            self.old_logger = noop

    def before_epoch(self):
        if getattr(self, 'mbar', False):
            self.mbar.update(self.epoch)

    def before_train(self):
        self._launch_pbar()

    def before_validate(self):
        self._launch_pbar()

    def after_train(self):
        self.pbar.on_iter_end()

    def after_validate(self):
        self.pbar.on_iter_end()

    def after_batch(self):
        self.pbar.update(self.iter+1)
        # `smooth_loss` may not be reachable (hence the `hasattr` guard); show it when it is.
        if hasattr(self, 'smooth_loss'):
            self.pbar.comment = f'{self.smooth_loss.item():.4f}'

    def _launch_pbar(self):
        "Create a fresh bar over the current dataloader, nested under the master bar if there is one"
        loader = self.dl
        parent_bar = getattr(self, 'mbar', None)
        self.pbar = progress_bar(loader, parent=parent_bar, leave=False)
        self.pbar.update(0)

    def after_fit(self):
        bar = getattr(self, 'mbar', False)
        if bar:
            bar.on_iter_end()
            del self.mbar
        if hasattr(self, 'old_logger'):
            self.learn.logger = self.old_logger

    def _write_stats(self, log):
        "Write `log` as one table row of the master bar; floats are shown with six decimals"
        if not getattr(self, 'mbar', False): return
        cells = []
        for entry in log:
            cells.append(f'{entry:.6f}' if isinstance(entry, float) else str(entry))
        self.mbar.write(cells, table=True)

    # Docstrings for the public event handlers, attached by the `@docs` decorator.
    _docs = dict(before_fit="Setup the master bar over the epochs",
                 before_epoch="Update the master bar",
                 before_train="Launch a progress bar over the training dataloader",
                 before_validate="Launch a progress bar over the validation dataloader",
                 after_train="Close the progress bar over the training dataloader",
                 after_validate="Close the progress bar over the validation dataloader",
                 after_batch="Update the current progress bar",
                 after_fit="Close the master bar")

# Make `ProgressCallback` one of the default callbacks: create the default list
# when none exists yet, otherwise append to it unless it is already present.
if not hasattr(defaults, 'callbacks'):
    defaults.callbacks = [TrainEvalCallback, Recorder, ProgressCallback]
elif ProgressCallback not in defaults.callbacks:
    defaults.callbacks.append(ProgressCallback)

# Demo: fit a synthetic learner for 5 epochs with the progress bars enabled.
learn = synth_learner()
learn.fit(5)
epoch train_loss valid_loss time
0 14.523648 10.988108 00:00
1 12.395808 7.306935 00:00
2 10.121231 4.370981 00:00
3 8.065226 2.487984 00:00
4 6.374166 1.368232 00:00
@patch
@contextmanager
def no_bar(self:Learner):
    "Context manager that runs its body with the learner's progress callback removed"
    had_progress = hasattr(self, 'progress')
    if had_progress:
        self.remove_cb(self.progress)
    try:
        yield self
    finally:
        # Only re-attach a progress callback if one was present on entry;
        # note that a new `ProgressCallback` instance is added, not the removed one.
        if had_progress:
            self.add_cb(ProgressCallback())
# Demo: fit a fresh synthetic learner for 5 epochs inside `no_bar`, i.e. with the
# progress callback removed for the duration of the `with` block.
learn = synth_learner()
with learn.no_bar(): learn.fit(5)
[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']
# Check that `validate` works without any training
def tst_metric(out, targ):
    "Test metric: mean squared error between `out` and `targ`"
    return F.mse_loss(out, targ)
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.validate()
# Check that `get_preds` works without any training
# (this cell previously called `validate` again, so `get_preds` was never exercised)
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.get_preds()
show_doc(ProgressCallback.before_fit)

source

ProgressCallback.before_fit

 ProgressCallback.before_fit ()

Setup the master bar over the epochs

show_doc(ProgressCallback.before_epoch)

source

ProgressCallback.before_epoch

 ProgressCallback.before_epoch ()

Update the master bar

show_doc(ProgressCallback.before_train)

source

ProgressCallback.before_train

 ProgressCallback.before_train ()

Launch a progress bar over the training dataloader

show_doc(ProgressCallback.before_validate)

source

ProgressCallback.before_validate

 ProgressCallback.before_validate ()

Launch a progress bar over the validation dataloader

show_doc(ProgressCallback.after_batch)

source

ProgressCallback.after_batch

 ProgressCallback.after_batch ()

Update the current progress bar

show_doc(ProgressCallback.after_train)

source

ProgressCallback.after_train

 ProgressCallback.after_train ()

Close the progress bar over the training dataloader

show_doc(ProgressCallback.after_validate)

source

ProgressCallback.after_validate

 ProgressCallback.after_validate ()

Close the progress bar over the validation dataloader

show_doc(ProgressCallback.after_fit)

source

ProgressCallback.after_fit

 ProgressCallback.after_fit ()

Close the master bar

ShowGraphCallback -

class ShowGraphCallback(Callback):
    "Redraw a graph of the training and validation losses after each epoch"
    order = 65
    run_valid = False

    def before_fit(self):
        # Disable this callback when the learner has an `lr_finder` attribute or
        # when a `gather_preds` attribute is reachable from this callback.
        skip = hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds")
        self.run = not skip
        if skip: return
        self.nb_batches = []
        # The graph is drawn through the progress callback's master bar.
        assert hasattr(self.learn, 'progress')

    def after_train(self):
        # Record the training-iteration count reached at the end of this epoch;
        # these values are the x positions of the validation-loss points.
        self.nb_batches.append(self.train_iter)

    def after_epoch(self):
        "Update the loss graph attached to the master progress bar"
        if not self.nb_batches: return
        recorder = self.learn.recorder
        train_x = range_of(recorder.losses)
        valid_y = [row[1] for row in recorder.values]
        # Extend the x axis to the end of the fit, using the first entry of
        # `nb_batches` as the estimated length of each remaining epoch.
        epochs_left = self.n_epoch - len(self.nb_batches)
        x_max = epochs_left * self.nb_batches[0] + len(recorder.losses)
        y_max = max((max(Tensor(recorder.losses)), max(Tensor(valid_y))))
        curves = [(train_x, recorder.losses), (self.nb_batches, valid_y)]
        self.progress.mbar.update_graph(curves, (0, x_max), (0, y_max))

::: {#cell-24 .cell 0=‘缓’ 1=‘慢’}

# Demo: fit a synthetic learner for 5 epochs with `ShowGraphCallback` attached.
learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)
epoch train_loss valid_loss time
0 17.683565 10.431150 00:00
1 15.232769 7.056944 00:00
2 12.470916 4.382421 00:00
3 10.000675 2.574951 00:00
4 7.943449 1.464153 00:00

:::

# Call `predict` on the learner created above, which still has `ShowGraphCallback` attached.
learn.predict(torch.tensor([[0.1]]))
(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))

CSVLogger -

class CSVLogger(Callback):
    "Log the results displayed in `learn.path/fname`"
    order=60
    def __init__(self, fname='history.csv', append=False):
        # `fname` is joined onto the learner's `path` when the file is opened or read.
        # `append=True` opens the file in append mode instead of overwriting it.
        self.fname,self.append = Path(fname),append

    def read_log(self):
        "Convenience method to quickly access the log."
        return pd.read_csv(self.path/self.fname)

    def before_fit(self):
        "Prepare file with metric names."
        if hasattr(self, "gather_preds"): return
        log_path = self.path/self.fname
        # Create the directory that actually holds the log file. Creating only
        # `self.path.parent` left `self.path` itself (and any sub-directory in
        # `fname`) missing, so the `open` below could fail with FileNotFoundError.
        log_path.parent.mkdir(parents=True, exist_ok=True)
        self.file = log_path.open('a' if self.append else 'w')
        # The header row is written on every fit, including in append mode.
        self.file.write(','.join(self.recorder.metric_names) + '\n')
        # Intercept the learner's logger; `_write_line` forwards to the previous one.
        self.old_logger,self.learn.logger = self.logger,self._write_line

    def _write_line(self, log):
        "Write a line with `log` and call the old logger."
        self.file.write(','.join([str(t) for t in log]) + '\n')
        # Flush Python's buffer and ask the OS to sync, so each row reaches the
        # file as soon as it is logged rather than when the file is closed.
        self.file.flush()
        os.fsync(self.file.fileno())
        self.old_logger(log)

    def after_fit(self):
        "Close the file and clean up."
        if hasattr(self, "gather_preds"): return
        self.file.close()
        self.learn.logger = self.old_logger

如果 `append` 为 `True`，结果将追加到已有文件的末尾；否则将覆盖该文件。

# Demo: fit a synthetic learner for 5 epochs with a default `CSVLogger` attached.
learn = synth_learner(cbs=CSVLogger())
learn.fit(5)
epoch train_loss valid_loss time
0 15.606769 14.485189 00:00
1 13.840394 10.834929 00:00
2 11.842106 7.582738 00:00
3 9.937692 5.158300 00:00
4 8.244681 3.432087 00:00
show_doc(CSVLogger.read_log)

source

CSVLogger.read_log

 CSVLogger.read_log ()

Convenience method to quickly access the log.

# Read the CSV log back and compare it with what the recorder stored:
# the header must equal the metric names, and the first three columns of row `i`
# must be close to the epoch index followed by that epoch's recorded values
# (presumably this skips the trailing `time` column, which is not numeric).
df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
    test_close(df.iloc[i][:3], [i] + v)
# Remove the log file so the test leaves nothing behind.
os.remove(learn.path/learn.csv_logger.fname)
show_doc(CSVLogger.before_fit)

source

CSVLogger.before_fit

 CSVLogger.before_fit ()

Prepare file with metric names.

show_doc(CSVLogger.after_fit)

source

CSVLogger.after_fit

 CSVLogger.after_fit ()

Close the file and clean up.

导出 -

# Run nbdev's export step for this notebook.
from nbdev import nbdev_export
nbdev_export()