! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on Colab
Progress and logging
from __future__ import annotations
from fastai.basics import *
from nbdev.showdoc import *
Callbacks and helper functions to track training progress or log results
from fastai.test_utils import *
ProgressCallback -
@docs
class ProgressCallback(Callback):
    "A `Callback` to handle the display of progress bars"
    order,_stateattrs = 60,('mbar','pbar')

    def before_fit(self):
        assert hasattr(self.learn, 'recorder')
        if self.create_mbar: self.mbar = master_bar(list(range(self.n_epoch)))
        if self.learn.logger != noop:
            self.old_logger,self.learn.logger = self.logger,self._write_stats
            self._write_stats(self.recorder.metric_names)
        else: self.old_logger = noop

    def before_epoch(self):
        if getattr(self, 'mbar', False): self.mbar.update(self.epoch)

    def before_train(self):    self._launch_pbar()
    def before_validate(self): self._launch_pbar()
    def after_train(self):     self.pbar.on_iter_end()
    def after_validate(self):  self.pbar.on_iter_end()

    def after_batch(self):
        self.pbar.update(self.iter+1)
        if hasattr(self, 'smooth_loss'): self.pbar.comment = f'{self.smooth_loss.item():.4f}'

    def _launch_pbar(self):
        self.pbar = progress_bar(self.dl, parent=getattr(self, 'mbar', None), leave=False)
        self.pbar.update(0)

    def after_fit(self):
        if getattr(self, 'mbar', False):
            self.mbar.on_iter_end()
            delattr(self, 'mbar')
        if hasattr(self, 'old_logger'): self.learn.logger = self.old_logger

    def _write_stats(self, log):
        if getattr(self, 'mbar', False): self.mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in log], table=True)

    _docs = dict(before_fit="Setup the master bar over the epochs",
                 before_epoch="Update the master bar",
                 before_train="Launch a progress bar over the training dataloader",
                 before_validate="Launch a progress bar over the validation dataloader",
                 after_train="Close the progress bar over the training dataloader",
                 after_validate="Close the progress bar over the validation dataloader",
                 after_batch="Update the current progress bar",
                 after_fit="Close the master bar")
if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback, Recorder, ProgressCallback]
elif ProgressCallback not in defaults.callbacks: defaults.callbacks.append(ProgressCallback)
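Since `ProgressCallback` is registered in `defaults.callbacks`, every new `Learner` picks it up automatically and exposes it under its snake_case name. A quick illustrative check, using `synth_learner` from `fastai.test_utils`:

# Illustrative check: the default callbacks include ProgressCallback,
# which the Learner stores as the `progress` attribute.
learn = synth_learner()
assert hasattr(learn, 'progress')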
learn = synth_learner()
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 14.523648 | 10.988108 | 00:00 |
1 | 12.395808 | 7.306935 | 00:00 |
2 | 10.121231 | 4.370981 | 00:00 |
3 | 8.065226 | 2.487984 | 00:00 |
4 | 6.374166 | 1.368232 | 00:00 |
@patch
@contextmanager
def no_bar(self:Learner):
    "Context manager that deactivates the use of progress bars"
    has_progress = hasattr(self, 'progress')
    if has_progress: self.remove_cb(self.progress)
    try: yield self
    finally:
        if has_progress: self.add_cb(ProgressCallback())
learn = synth_learner()
with learn.no_bar(): learn.fit(5)
[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']
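To silence the printed stat lines as well as the bars, `no_bar` can be combined with fastai's `Learner.no_logging` context manager; a minimal sketch:

# Sketch: suppress both the progress bars and the per-epoch stats output.
learn = synth_learner()
with learn.no_bar(), learn.no_logging():
    learn.fit(1)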
# Check that validate works without any training
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.validate()
# Check that get_preds works without any training
learn = synth_learner(n_trn=5, metrics=tst_metric)
preds,targs = learn.get_preds()
show_doc(ProgressCallback.before_fit)
ProgressCallback.before_fit()
Setup the master bar over the epochs
show_doc(ProgressCallback.before_epoch)
show_doc(ProgressCallback.before_train)
ProgressCallback.before_train()
Launch a progress bar over the training dataloader
show_doc(ProgressCallback.before_validate)
ProgressCallback.before_validate()
Launch a progress bar over the validation dataloader
show_doc(ProgressCallback.after_batch)
show_doc(ProgressCallback.after_train)
ProgressCallback.after_train()
Close the progress bar over the training dataloader
show_doc(ProgressCallback.after_validate)
ProgressCallback.after_validate()
Close the progress bar over the validation dataloader
show_doc(ProgressCallback.after_fit)
ShowGraphCallback -
class ShowGraphCallback(Callback):
    "Update a graph of training and validation loss"
    order,run_valid = 65,False

    def before_fit(self):
        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds")
        if not(self.run): return
        self.nb_batches = []
        assert hasattr(self.learn, 'progress')

    def after_train(self): self.nb_batches.append(self.train_iter)

    def after_epoch(self):
        "Plot validation loss in the pbar graph"
        if not self.nb_batches: return
        rec = self.learn.recorder
        iters = range_of(rec.losses)
        val_losses = [v[1] for v in rec.values]
        x_bounds = (0, (self.n_epoch - len(self.nb_batches)) * self.nb_batches[0] + len(rec.losses))
        y_bounds = (0, max((max(Tensor(rec.losses)), max(Tensor(val_losses)))))
        self.progress.mbar.update_graph([(iters, rec.losses), (self.nb_batches, val_losses)], x_bounds, y_bounds)
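Note how `after_epoch` computes the axis bounds: the x-axis upper bound projects the total number of training iterations from the batches seen so far. For example (hypothetical numbers), with `n_epoch=5` and 10 batches per epoch, after two epochs it is `(5 - 2) * 10 + 20 = 50`, so the graph keeps a fixed scale across the whole run instead of rescaling every epoch.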
learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 17.683565 | 10.431150 | 00:00 |
1 | 15.232769 | 7.056944 | 00:00 |
2 | 12.470916 | 4.382421 | 00:00 |
3 | 10.000675 | 2.574951 | 00:00 |
4 | 7.943449 | 1.464153 | 00:00 |
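`ShowGraphCallback` can also be attached for a single call to `fit` rather than at `Learner` creation; a minimal sketch, again using `synth_learner`:

# Sketch: enable the live loss graph for one training run only.
learn = synth_learner()
learn.fit(5, cbs=ShowGraphCallback())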
learn.predict(torch.tensor([[0.1]]))
(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))
CSVLogger -
class CSVLogger(Callback):
    "Log the results displayed in `learn.path/fname`"
    order = 60
    def __init__(self, fname='history.csv', append=False):
        self.fname,self.append = Path(fname),append

    def read_log(self):
        "Convenience method to quickly access the log."
        return pd.read_csv(self.path/self.fname)

    def before_fit(self):
        "Prepare file with metric names."
        if hasattr(self, "gather_preds"): return
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.file = (self.path/self.fname).open('a' if self.append else 'w')
        self.file.write(','.join(self.recorder.metric_names) + '\n')
        self.old_logger,self.learn.logger = self.logger,self._write_line

    def _write_line(self, log):
        "Write a line with `log` and call the old logger."
        self.file.write(','.join([str(t) for t in log]) + '\n')
        self.file.flush()
        os.fsync(self.file.fileno())
        self.old_logger(log)

    def after_fit(self):
        "Close the file and clean up."
        if hasattr(self, "gather_preds"): return
        self.file.close()
        self.learn.logger = self.old_logger
If `append` is set to `True`, the results are appended to an existing file; otherwise, the file is overwritten.
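A small sketch of `append=True`, assuming `synth_learner`: a second `fit` adds its rows to the same `history.csv` instead of overwriting it (note that `before_fit` writes the header line on every run, so the file will contain a repeated header).

# Sketch: two consecutive fits accumulate in the same CSV file.
learn = synth_learner(cbs=CSVLogger(append=True))
learn.fit(2)
learn.fit(2)  # appended below the first run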
learn = synth_learner(cbs=CSVLogger())
learn.fit(5)
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | 15.606769 | 14.485189 | 00:00 |
1 | 13.840394 | 10.834929 | 00:00 |
2 | 11.842106 | 7.582738 | 00:00 |
3 | 9.937692 | 5.158300 | 00:00 |
4 | 8.244681 | 3.432087 | 00:00 |
show_doc(CSVLogger.read_log)
df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
    test_close(df.iloc[i][:3], [i] + v)
os.remove(learn.path/learn.csv_logger.fname)
show_doc(CSVLogger.before_fit)
show_doc(CSVLogger.after_fit)
Export -
from nbdev import nbdev_export
nbdev_export()