topobench.callbacks.best_epoch_metrics module#

Callback to track all metrics at the epoch when the monitored metric is best.

class BestEpochMetricsCallback(monitor, mode='min')#

Bases: Callback

Tracks all metrics at the epoch when the monitored metric is best.

This callback captures both training and validation metrics from the same epoch where the monitored metric (e.g., val/loss) achieves its best value. Unlike tracking the best value for each metric independently, this ensures all metrics are from the same checkpoint/epoch.

The metrics are logged with the prefix ‘best_epoch/’ to distinguish them from the running metrics and independent best metrics.

Parameters:
monitor : str

The metric to monitor (e.g., “val/loss”).

mode : str, optional

Whether to minimize (“min”) or maximize (“max”) the monitored metric (default: “min”).

Examples

If validation loss is the monitored metric and reaches its minimum at epoch 42, this callback will log all of the following, taken from epoch 42:

  • best_epoch/train/loss

  • best_epoch/train/accuracy

  • best_epoch/val/loss

  • best_epoch/val/accuracy

  • best_epoch/val/f1
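
A minimal usage sketch, assuming the import path shown in the module header above; the Lightning module and dataloaders are hypothetical placeholders:

from lightning.pytorch import Trainer
from topobench.callbacks.best_epoch_metrics import BestEpochMetricsCallback

best_epoch_cb = BestEpochMetricsCallback(monitor="val/loss", mode="min")
trainer = Trainer(max_epochs=100, callbacks=[best_epoch_cb])
trainer.fit(MyLitModule(), train_loader, val_loader)  # hypothetical model and dataloaders

# After training, the logger contains `best_epoch/...` entries, all captured at
# the epoch where `val/loss` reached its minimum.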

__init__(monitor, mode='min')#
on_train_end(trainer, pl_module)#

Log the best model checkpoint path and metadata at the end of training.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_train_epoch_end(trainer, pl_module)#

Capture training metrics at the end of the training phase of each epoch.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_train_start(trainer, pl_module)#

Find and store a reference to the ModelCheckpoint callback, used to resolve the checkpoint path.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_validation_epoch_end(trainer, pl_module)#

Check if this is the best epoch and capture all metrics if so.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

class Callback#

Bases: object

Abstract base class used to build new callbacks.

Subclass this class and override any of the relevant hooks.

load_state_dict(state_dict)#

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

on_after_backward(trainer, pl_module)#

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer, pl_module, loss)#

Called before loss.backward().

on_before_optimizer_step(trainer, pl_module, optimizer)#

Called before optimizer.step().

on_before_zero_grad(trainer, pl_module, optimizer)#

Called before optimizer.zero_grad().

on_exception(trainer, pl_module, exception)#

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer, pl_module)#

Called when fit ends.

on_fit_start(trainer, pl_module)#

Called when fit begins.

on_load_checkpoint(trainer, pl_module, checkpoint)#

Called when loading a model checkpoint, use to reload state.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the predict batch ends.

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the predict batch begins.

on_predict_end(trainer, pl_module)#

Called when predict ends.

on_predict_epoch_end(trainer, pl_module)#

Called when the predict epoch ends.

on_predict_epoch_start(trainer, pl_module)#

Called when the predict epoch begins.

on_predict_start(trainer, pl_module)#

Called when the predict begins.

on_sanity_check_end(trainer, pl_module)#

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)#

Called when the validation sanity check starts.

on_save_checkpoint(trainer, pl_module, checkpoint)#

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the test batch ends.

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the test batch begins.

on_test_end(trainer, pl_module)#

Called when the test ends.

on_test_epoch_end(trainer, pl_module)#

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)#

Called when the test epoch begins.

on_test_start(trainer, pl_module)#

Called when the test begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.

on_train_batch_start(trainer, pl_module, batch, batch_idx)#

Called when the train batch begins.

on_train_end(trainer, pl_module)#

Called when the train ends.

on_train_epoch_end(trainer, pl_module)#

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer, pl_module)#

Called when the train epoch begins.

on_train_start(trainer, pl_module)#

Called when the train begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the validation batch ends.

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the validation batch begins.

on_validation_end(trainer, pl_module)#

Called when the validation loop ends.

on_validation_epoch_end(trainer, pl_module)#

Called when the val epoch ends.

on_validation_epoch_start(trainer, pl_module)#

Called when the val epoch begins.

on_validation_start(trainer, pl_module)#

Called when the validation loop begins.

setup(trainer, pl_module, stage)#

Called when fit, validate, test, predict, or tune begins.

state_dict()#

Called when saving a checkpoint, implement to generate callback’s state_dict.

Returns:

A dictionary containing callback state.

Return type:

Dict[str, Any]
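
For illustration, a hedged sketch of a stateful callback that persists a simple counter across checkpoints via these two hooks (the class and attribute names are hypothetical):

import lightning.pytorch as pl

class CounterCallback(pl.Callback):
    """Hypothetical callback whose state survives checkpointing."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        # Stored by the Trainer under checkpoint["callbacks"][self.state_key].
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # Restore the counter when training resumes from a checkpoint.
        self.batches_seen = state_dict["batches_seen"]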

teardown(trainer, pl_module, stage)#

Called when fit, validate, test, predict, or tune ends.

property state_key: str#

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
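
As a hedged illustration of the lookup described above (the checkpoint path is a placeholder, and CounterCallback refers to the sketch a few entries up):

import torch

ckpt = torch.load("my/path/epoch=0-step=10.ckpt", map_location="cpu", weights_only=False)
# Each stateful callback's state is keyed by its state_key, e.g. 'CounterCallback'.
print(ckpt["callbacks"].keys())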

class ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)#

Bases: Checkpoint

Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() is a candidate for the monitor key. For more information, see checkpointing.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters:
  • dirpath (str | Path | None) –

    directory to save the model file.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    

    By default, dirpath is None and will be set at runtime to the location specified by the Trainer’s default_root_dir argument, and if the Trainer uses a logger, the path will also contain logger name and version.

  • filename (str | None) –

    checkpoint filename. Can contain named formatting options to be auto-filled.

    Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filename is None and will be set to '{epoch}-{step}', where “epoch” and “step” match the number of finished epoch and optimizer steps respectively.

  • monitor (str | None) – quantity to monitor. By default it is None which saves a checkpoint only for the last epoch.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (bool | Literal['link'] | None) – When True, saves a last.ckpt copy whenever a checkpoint file gets saved. Can be set to 'link' on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint in a deterministic manner. Default: None.

  • save_top_k (int) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, and the filename remains unchanged, the name of the saved file will be appended with a version count starting with v1 to avoid collisions unless enable_version_counter is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions.

  • mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should be 'min', etc.

  • auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02.0f}' with epoch 1 and acc 1.12 will resolve to checkpoint_epoch=01-acc=01.ckpt. It is useful to set it to False when metric names contain / as this will result in extra folders. For example, filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False

  • save_weights_only (bool) – if True, then only the model’s weights will be saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.

  • every_n_train_steps (int | None) – Number of training steps between checkpoints. If every_n_train_steps == None or every_n_train_steps == 0, we skip saving during training. To disable, set every_n_train_steps = 0. This value must be None or non-negative. This must be mutually exclusive with train_time_interval and every_n_epochs.

  • train_time_interval (timedelta | None) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with every_n_train_steps and every_n_epochs.

  • every_n_epochs (int | None) – Number of epochs between checkpoints. This value must be None or non-negative. To disable saving top-k checkpoints, set every_n_epochs = 0. This argument does not impact the saving of save_last=True checkpoints. If all of every_n_epochs, every_n_train_steps and train_time_interval are None, we save a checkpoint at the end of every epoch (equivalent to every_n_epochs = 1). If every_n_epochs == None and either every_n_train_steps != None or train_time_interval != None, saving at the end of each epoch is disabled (equivalent to every_n_epochs = 0). This must be mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both values for every_n_epochs and check_val_every_n_epoch evenly divide E.

  • save_on_train_epoch_end (bool | None) – Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of the validation.

  • enable_version_counter (bool) – Whether to append a version to the existing file name. If this is False, then the checkpoint files will be overwritten.

Note

For extra customization, ModelCheckpoint includes the following attributes:

  • CHECKPOINT_JOIN_CHAR = "-"

  • CHECKPOINT_EQUALS_CHAR = "="

  • CHECKPOINT_NAME_LAST = "last"

  • FILE_EXTENSION = ".ckpt"

  • STARTING_VERSION = 1

For example, you can change the default last checkpoint name by doing checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"

If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.

If the checkpoint’s dirpath changed from what it was before while resuming the training, only best_model_path will be reloaded and a warning will be issued.

Raises:
  • MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min" or "max".

  • ValueError – If trainer.save_checkpoint is None.

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path

Tip

Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments:

monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval

Read more: Persisting Callback State

__init__(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)#
check_monitor_top_k(trainer, current=None)#
file_exists(filepath, trainer)#

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

format_checkpoint_name(metrics, filename=None, ver=None)#

Generate a filename according to the defined template.

Example:

>>> import os
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
'epoch=2.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'

load_state_dict(state_dict)#

Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#

Save a checkpoint on train batch end if we meet the criteria for every_n_train_steps.

on_train_epoch_end(trainer, pl_module)#

Save a checkpoint at the end of the training epoch.

on_train_start(trainer, pl_module)#

Called when the train begins.

on_validation_end(trainer, pl_module)#

Save a checkpoint at the end of the validation stage.

setup(trainer, pl_module, stage)#

Called when fit, validate, test, predict, or tune begins.

state_dict()#

Called when saving a checkpoint, implement to generate callback’s state_dict.

Returns:

A dictionary containing callback state.

Return type:

Dict[str, Any]

to_yaml(filepath=None)#

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.
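
A hedged usage sketch (the directory, monitor, and top-k values are illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath='my/path/', monitor='val_loss', save_top_k=3)
trainer = Trainer(callbacks=[checkpoint_callback])
# ... after trainer.fit(model) ...
# Writes a mapping of checkpoint path -> monitored score for the retained top-k checkpoints.
checkpoint_callback.to_yaml('my/path/best_k_models.yaml')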

CHECKPOINT_EQUALS_CHAR = '='#
CHECKPOINT_JOIN_CHAR = '-'#
CHECKPOINT_NAME_LAST = 'last'#
FILE_EXTENSION = '.ckpt'#
STARTING_VERSION = 1#
property every_n_epochs: int | None#

property state_key: str#

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.