topobench.callbacks package#

Callbacks for training, validation, and testing stages.

class BestEpochMetricsCallback(monitor, mode='min')#

Bases: Callback

Tracks all metrics at the epoch when the monitored metric is best.

This callback captures both training and validation metrics from the same epoch where the monitored metric (e.g., val/loss) achieves its best value. Unlike tracking the best value for each metric independently, this ensures all metrics are from the same checkpoint/epoch.

The metrics are logged with the prefix ‘best_epoch/’ to distinguish them from the running metrics and independent best metrics.

Parameters:
monitor : str

The metric to monitor (e.g., “val/loss”).

mode : str, optional

Whether to minimize (“min”) or maximize (“max”) the monitored metric (default: “min”).

Examples

If validation loss is the monitored metric and reaches its minimum at epoch 42, this callback will log:

- best_epoch/train/loss
- best_epoch/train/accuracy
- best_epoch/val/loss
- best_epoch/val/accuracy
- best_epoch/val/f1

etc., all from epoch 42.
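The snapshot-at-best-epoch logic described above can be sketched framework-free. The class and method names below (`BestEpochTracker`, `on_epoch_end`) are hypothetical stand-ins, not the actual callback's internals; the point is that when the monitored metric improves, *all* metrics from that epoch are captured together under the `best_epoch/` prefix:

```python
# Hypothetical sketch of best-epoch tracking: snapshot every metric
# from the epoch where the monitored metric achieves its best value.

class BestEpochTracker:
    def __init__(self, monitor, mode="min"):
        self.monitor = monitor
        self.mode = mode
        self.best_value = float("inf") if mode == "min" else float("-inf")
        self.best_epoch_metrics = {}

    def _is_better(self, value):
        if self.mode == "min":
            return value < self.best_value
        return value > self.best_value

    def on_epoch_end(self, metrics):
        """metrics: dict of all metrics logged for the current epoch."""
        value = metrics.get(self.monitor)
        if value is None:
            return
        if self._is_better(value):
            self.best_value = value
            # Snapshot every metric from this epoch, not just the monitored one.
            self.best_epoch_metrics = {
                f"best_epoch/{k}": v for k, v in metrics.items()
            }

tracker = BestEpochTracker(monitor="val/loss", mode="min")
tracker.on_epoch_end({"val/loss": 0.9, "val/accuracy": 0.60, "train/loss": 1.0})
tracker.on_epoch_end({"val/loss": 0.5, "val/accuracy": 0.75, "train/loss": 0.6})
tracker.on_epoch_end({"val/loss": 0.7, "val/accuracy": 0.80, "train/loss": 0.4})
# best_value is 0.5; all snapshot metrics come from the second epoch,
# even though val/accuracy alone peaked in the third.
```

This is why the third epoch's higher accuracy is not recorded: consistency across metrics from a single checkpoint is the goal, not the best value of each metric independently.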

__init__(monitor, mode='min')#
on_train_end(trainer, pl_module)#

Log the best model checkpoint path and metadata at the end of training.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_train_epoch_end(trainer, pl_module)#

Capture training metrics at the end of the training phase.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_train_start(trainer, pl_module)#

Find and store a reference to the ModelCheckpoint callback so the checkpoint path can be retrieved later.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

on_validation_epoch_end(trainer, pl_module)#

Check whether this is the best epoch so far and, if so, capture all metrics.

Parameters:
trainer : Trainer

The PyTorch Lightning trainer.

pl_module : LightningModule

The PyTorch Lightning module being trained.

class PipelineTimer#

Bases: Callback

Measures and logs average execution times of training, validation, and testing stages.

__init__()#

Initialize dictionaries to store accumulated times and counts.
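The accumulate-and-average pattern that the hook pairs below implement (start timing on `*_start`, accumulate elapsed time and a count on `*_end`, report averages when training finishes) can be sketched as follows. `StageTimer` and its method names are hypothetical, not the actual internals of `PipelineTimer`:

```python
import time

# Hypothetical sketch of per-stage timing: record a start timestamp,
# accumulate elapsed seconds and a completion count per stage, and
# compute averages at the end.

class StageTimer:
    def __init__(self):
        self.total_times = {}   # stage name -> accumulated seconds
        self.counts = {}        # stage name -> number of completed intervals
        self._starts = {}       # stage name -> last start timestamp

    def start(self, stage):
        self._starts[stage] = time.perf_counter()

    def end(self, stage):
        elapsed = time.perf_counter() - self._starts.pop(stage)
        self.total_times[stage] = self.total_times.get(stage, 0.0) + elapsed
        self.counts[stage] = self.counts.get(stage, 0) + 1

    def averages(self):
        return {s: self.total_times[s] / self.counts[s] for s in self.total_times}

timer = StageTimer()
for _ in range(3):
    timer.start("train_batch")
    time.sleep(0.001)  # stand-in for a real training step
    timer.end("train_batch")
avg = timer.averages()["train_batch"]
```

Averaging over counts rather than logging raw totals makes batch and epoch timings comparable across runs with different numbers of batches.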

on_test_batch_end(*args)#

End timing a test batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_test_batch_start(*args)#

Start timing a test batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_test_epoch_end(*args)#

End timing a test epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_test_epoch_start(*args)#

Start timing a test epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_train_batch_end(*args)#

End timing a training batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_train_batch_start(*args)#

Start timing a training batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_train_end(trainer, *args)#

Log the average times at the end of training.

Parameters:
trainer : object

The PyTorch Lightning trainer instance used for logging.

*args : tuple

Additional arguments passed by the trainer.

on_train_epoch_end(*args)#

End timing a training epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_train_epoch_start(*args)#

Start timing a training epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_validation_batch_end(*args)#

End timing a validation batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_validation_batch_start(*args)#

Start timing a validation batch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_validation_epoch_end(*args)#

End timing a validation epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

on_validation_epoch_start(*args)#

Start timing a validation epoch.

Parameters:
*args : tuple

Additional arguments passed by the trainer.

Submodules#