topobench.callbacks package
Callbacks for training, validation, and testing stages.
- class BestEpochMetricsCallback(monitor, mode='min')
Bases: Callback
Tracks all metrics at the epoch where the monitored metric is best.
This callback captures both training and validation metrics from the same epoch in which the monitored metric (e.g., val/loss) reaches its best value. Unlike tracking the best value of each metric independently, this guarantees that all reported metrics come from the same checkpoint/epoch.
The metrics are logged with the prefix ‘best_epoch/’ to distinguish them from both the running metrics and the independently tracked best metrics.
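The difference between per-metric bests and a single best epoch can be illustrated with a small, framework-free sketch. The history below is a hypothetical list of plain dicts holding toy values (not Lightning objects or real results), and the helper names `independent_best` and `best_epoch_snapshot` are illustrative only.

```python
# Toy per-epoch history (hypothetical values, for illustration only).
history = [
    {"epoch": 0, "val/loss": 0.90, "val/accuracy": 0.60},
    {"epoch": 1, "val/loss": 0.70, "val/accuracy": 0.72},
    {"epoch": 2, "val/loss": 0.75, "val/accuracy": 0.74},
]


def independent_best(history, modes):
    """Best value of each metric on its own; values may come from different epochs."""
    result = {}
    for name, mode in modes.items():
        pick = min if mode == "min" else max
        result[name] = pick(row[name] for row in history)
    return result


def best_epoch_snapshot(history, monitor, mode="min"):
    """All metrics from the single epoch where `monitor` is best."""
    pick = min if mode == "min" else max
    best_row = pick(history, key=lambda row: row[monitor])
    return dict(best_row)


modes = {"val/loss": "min", "val/accuracy": "max"}
print(independent_best(history, modes))
# {'val/loss': 0.7, 'val/accuracy': 0.74}  <- mixes epochs 1 and 2
print(best_epoch_snapshot(history, "val/loss"))
# {'epoch': 1, 'val/loss': 0.7, 'val/accuracy': 0.72}
```

The independent view reports an accuracy that was never observed alongside the best loss; the best-epoch view keeps every number consistent with one checkpoint.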
- Parameters:
- monitor : str
The metric to monitor (e.g., “val/loss”).
- mode : str, optional
Whether to minimize (“min”) or maximize (“max”) the monitored metric (default: “min”).
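A minimal sketch of how the `mode` parameter might decide whether a new value of the monitored metric counts as an improvement. The helper `is_better` is hypothetical; treating NaN as never-best and breaking ties in favour of the earlier epoch are assumptions, not documented behaviour of `BestEpochMetricsCallback`.

```python
import math
from typing import Optional


def is_better(current: float, best: Optional[float], mode: str = "min") -> bool:
    """Return True if `current` strictly improves on `best` under `mode`."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if math.isnan(current):
        return False  # assumption: a NaN never becomes the best value
    if best is None:
        return True  # nothing recorded yet, so the first valid value wins
    # Strict comparison: on a tie, the earlier epoch is kept.
    return current < best if mode == "min" else current > best


print(is_better(0.40, 0.50, mode="min"))  # True
print(is_better(0.40, 0.50, mode="max"))  # False
```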
Examples
If validation loss is the monitored metric and reaches its minimum at epoch 42, this callback logs the following metrics (among others), all taken from epoch 42:
- best_epoch/train/loss
- best_epoch/train/accuracy
- best_epoch/val/loss
- best_epoch/val/accuracy
- best_epoch/val/f1
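The prefixing in this example amounts to a key rename. A sketch, assuming the captured metrics are held in a plain dict; the helper `prefix_metrics` and the metric values are hypothetical.

```python
from typing import Dict


def prefix_metrics(metrics: Dict[str, float], prefix: str = "best_epoch/") -> Dict[str, float]:
    """Return a new dict with every key prefixed, e.g. 'val/loss' -> 'best_epoch/val/loss'."""
    return {f"{prefix}{name}": value for name, value in metrics.items()}


# Metrics captured at the best epoch (hypothetical values).
snapshot = {"train/loss": 0.31, "val/loss": 0.42, "val/f1": 0.77}
print(prefix_metrics(snapshot))
# {'best_epoch/train/loss': 0.31, 'best_epoch/val/loss': 0.42, 'best_epoch/val/f1': 0.77}
```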
- __init__(monitor, mode='min')
- on_train_end(trainer, pl_module)
Log the best model checkpoint path and metadata at the end of training.
- Parameters:
- trainer : Trainer
The PyTorch Lightning trainer.
- pl_module : LightningModule
The PyTorch Lightning module being trained.
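A sketch of how checkpoint metadata might be gathered at the end of training. Lightning's `ModelCheckpoint` does expose `best_model_path` and `best_model_score`; everything else here, including the `checkpoint_summary` helper and the `best_epoch/...` key names, is a hypothetical illustration. A `SimpleNamespace` stands in for the real callback so the snippet runs without Lightning installed.

```python
from types import SimpleNamespace
from typing import Any, Optional


def checkpoint_summary(checkpoint_cb: Any, best_epoch: Optional[int]) -> dict:
    """Collect best-checkpoint metadata into a flat dict (hypothetical key names)."""
    summary = {"best_epoch/epoch": best_epoch}
    if checkpoint_cb is None:
        return summary  # no ModelCheckpoint attached: nothing more to report
    summary["best_epoch/checkpoint_path"] = checkpoint_cb.best_model_path
    score = checkpoint_cb.best_model_score
    # In Lightning this is usually a 0-dim tensor or None; float() accepts tensors and plain numbers.
    summary["best_epoch/checkpoint_score"] = None if score is None else float(score)
    return summary


# Stand-in for a ModelCheckpoint instance; only the two attributes read above exist.
stub = SimpleNamespace(best_model_path="path/to/best.ckpt", best_model_score=0.42)
print(checkpoint_summary(stub, best_epoch=42))
```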
- on_train_epoch_end(trainer, pl_module)
Capture training metrics at the end of each training epoch.
- Parameters:
- trainer : Trainer
The PyTorch Lightning trainer.
- pl_module : LightningModule
The PyTorch Lightning module being trained.
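A sketch of the capture step, assuming training metrics are the entries whose names start with `train/` (matching the `best_epoch/train/...` names in the class example; the real callback's selection rule is not documented here). A plain dict stands in for Lightning's `trainer.callback_metrics`, and `capture_train_metrics` is a hypothetical name.

```python
from typing import Any, Dict


def capture_train_metrics(callback_metrics: Dict[str, Any]) -> Dict[str, float]:
    """Copy the 'train/...' entries out of a metrics dict into a new dict.

    `callback_metrics` plays the role of `trainer.callback_metrics`, whose values
    are typically 0-dim tensors; float() works for those and for plain numbers.
    """
    return {k: float(v) for k, v in callback_metrics.items() if k.startswith("train/")}


metrics = {"train/loss": 0.52, "train/accuracy": 0.81, "val/loss": 0.60}
print(capture_train_metrics(metrics))  # {'train/loss': 0.52, 'train/accuracy': 0.81}
```

Copying into a fresh dict keeps the captured values fixed even if the source mapping is updated later in the run.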
- on_train_start(trainer, pl_module)
Find and store a reference to the ModelCheckpoint callback, which is used to look up the checkpoint path.
- Parameters:
- trainer : Trainer
The PyTorch Lightning trainer.
- pl_module : LightningModule
The PyTorch Lightning module being trained.
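One way to locate the checkpoint callback is an `isinstance` scan over the trainer's callback list, sketched below with stand-in classes so it runs without Lightning. Lightning's `Trainer` also provides a `trainer.callbacks` list and a `trainer.checkpoint_callback` property (the first `ModelCheckpoint`, or `None`); which approach TopoBench uses is not stated here. The helper `find_callback` is hypothetical.

```python
from typing import Iterable, Optional, Type, TypeVar

T = TypeVar("T")


def find_callback(callbacks: Iterable[object], cls: Type[T]) -> Optional[T]:
    """Return the first callback that is an instance of `cls`, or None."""
    for cb in callbacks:
        if isinstance(cb, cls):
            return cb
    return None


# Stand-in classes (not the real Lightning ones) so the sketch is self-contained.
class ModelCheckpoint:
    pass


class EarlyStopping:
    pass


callbacks = [EarlyStopping(), ModelCheckpoint()]
print(type(find_callback(callbacks, ModelCheckpoint)).__name__)  # ModelCheckpoint
```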
- on_validation_epoch_end(trainer, pl_module)
Check whether this is the best epoch so far and, if so, capture all metrics.
- Parameters:
- trainer : Trainer
The PyTorch Lightning trainer.
- pl_module : LightningModule
The PyTorch Lightning module being trained.
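The per-epoch check can be sketched as a small stateful tracker: each time validation finishes, compare the monitored value against the best so far and, on improvement, replace the stored snapshot with every metric from that epoch. `BestEpochTracker` is a hypothetical, framework-free stand-in; skipping epochs where the monitored metric is missing or NaN is an assumption.

```python
import math
from typing import Dict, Optional


class BestEpochTracker:
    """Keep all metrics from the epoch where `monitor` is best (hypothetical helper)."""

    def __init__(self, monitor: str, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.mode = mode
        self.best_value: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.snapshot: Dict[str, float] = {}

    def update(self, epoch: int, metrics: Dict[str, float]) -> bool:
        """Compare this epoch's monitored value; on improvement, store every metric."""
        if self.monitor not in metrics:
            return False  # assumption: monitored metric not logged this epoch, skip
        value = float(metrics[self.monitor])
        if math.isnan(value):
            return False  # assumption: NaN never counts as an improvement
        improved = self.best_value is None or (
            value < self.best_value if self.mode == "min" else value > self.best_value
        )
        if improved:
            self.best_value = value
            self.best_epoch = epoch
            # Build a new prefixed dict so later epochs cannot alter the snapshot.
            self.snapshot = {f"best_epoch/{k}": float(v) for k, v in metrics.items()}
        return improved


tracker = BestEpochTracker(monitor="val/loss", mode="min")
tracker.update(0, {"train/loss": 1.0, "val/loss": 0.9})
tracker.update(1, {"train/loss": 0.6, "val/loss": 0.7})
tracker.update(2, {"train/loss": 0.4, "val/loss": 0.8})  # val/loss worse: snapshot kept
print(tracker.best_epoch, tracker.snapshot)
# 1 {'best_epoch/train/loss': 0.6, 'best_epoch/val/loss': 0.7}
```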
- class PipelineTimer
Bases: Callback
Measures and logs average execution times of the training, validation, and testing stages.
- __init__()
Initialize dictionaries to store accumulated times and counts.
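The start/end hook pairs suggest a simple accumulator: record a start time, and on the matching end add the elapsed time to a per-stage total and increment a per-stage count. This is a sketch under assumptions: `StageTimer` and the stage name `"train_batch"` are hypothetical, and `time.perf_counter` is one reasonable clock choice, not necessarily the one PipelineTimer uses.

```python
import time
from collections import defaultdict
from typing import Dict


class StageTimer:
    """Accumulate wall-clock time and interval counts per stage (hypothetical sketch)."""

    def __init__(self) -> None:
        self.sums: Dict[str, float] = defaultdict(float)  # total seconds per stage
        self.counts: Dict[str, int] = defaultdict(int)    # completed intervals per stage
        self._starts: Dict[str, float] = {}               # open interval start times

    def start(self, stage: str) -> None:
        self._starts[stage] = time.perf_counter()

    def end(self, stage: str) -> None:
        started = self._starts.pop(stage, None)
        if started is None:
            return  # end without a matching start: ignore rather than raise
        self.sums[stage] += time.perf_counter() - started
        self.counts[stage] += 1


timer = StageTimer()
for _ in range(3):
    timer.start("train_batch")
    time.sleep(0.001)  # placeholder for real batch work
    timer.end("train_batch")
print(timer.counts["train_batch"], timer.sums["train_batch"] / timer.counts["train_batch"])
```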
- on_test_batch_end(*args)
End timing a test batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_test_batch_start(*args)
Start timing a test batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_test_epoch_end(*args)
End timing a test epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_test_epoch_start(*args)
Start timing a test epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_train_batch_end(*args)
End timing a training batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_train_batch_start(*args)
Start timing a training batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_train_end(trainer, *args)
Log the average times at the end of training.
- Parameters:
- trainer : object
The PyTorch Lightning trainer instance used for logging.
- *args : tuple
Additional arguments passed by the trainer.
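Turning the accumulated totals into averages is one division per stage, skipping stages that never completed. A sketch with hypothetical names (`average_times`, the `time/` key prefix); the resulting flat `str -> float` dict is the kind of mapping a Lightning logger's `log_metrics` accepts, though how PipelineTimer actually logs is not documented here.

```python
from typing import Dict


def average_times(
    sums: Dict[str, float], counts: Dict[str, int], prefix: str = "time/"
) -> Dict[str, float]:
    """Turn accumulated totals into per-stage averages ready for logging.

    Stages with a zero (or missing) count are skipped to avoid division by zero.
    The 'time/' prefix is a hypothetical naming choice.
    """
    return {
        f"{prefix}{stage}": sums[stage] / counts[stage]
        for stage in sums
        if counts.get(stage, 0) > 0
    }


print(average_times({"train_batch": 3.0, "val_batch": 0.0}, {"train_batch": 6, "val_batch": 0}))
# {'time/train_batch': 0.5}
```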
- on_train_epoch_end(*args)
End timing a training epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_train_epoch_start(*args)
Start timing a training epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_validation_batch_end(*args)
End timing a validation batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_validation_batch_start(*args)
Start timing a validation batch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_validation_epoch_end(*args)
End timing a validation epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
- on_validation_epoch_start(*args)
Start timing a validation epoch.
- Parameters:
- *args : tuple
Additional arguments passed by the trainer.
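Every timing hook takes `*args`, which lets a single signature absorb whatever positional arguments Lightning passes (for example `trainer, pl_module, batch, batch_idx` for `on_train_batch_start`). A sketch of that delegation pattern with a hypothetical `MiniPipelineTimer`; only two stage pairs are shown, and the stage keys are illustrative rather than PipelineTimer's real ones.

```python
import time
from collections import defaultdict
from typing import Dict


class MiniPipelineTimer:
    """Hypothetical sketch: hooks ignore their *args and delegate to start/end helpers."""

    def __init__(self) -> None:
        self.sums: Dict[str, float] = defaultdict(float)
        self.counts: Dict[str, int] = defaultdict(int)
        self._starts: Dict[str, float] = {}

    def _start(self, stage: str) -> None:
        self._starts[stage] = time.perf_counter()

    def _end(self, stage: str) -> None:
        if stage in self._starts:
            self.sums[stage] += time.perf_counter() - self._starts.pop(stage)
            self.counts[stage] += 1

    # Each hook accepts whatever the trainer passes and uses none of it.
    def on_train_batch_start(self, *args) -> None:
        self._start("train_batch")

    def on_train_batch_end(self, *args) -> None:
        self._end("train_batch")

    def on_validation_epoch_start(self, *args) -> None:
        self._start("val_epoch")

    def on_validation_epoch_end(self, *args) -> None:
        self._end("val_epoch")


timer = MiniPipelineTimer()
# Dummy positional arguments stand in for trainer, pl_module, batch, batch_idx, etc.
timer.on_train_batch_start("trainer", "module", "batch", 0)
timer.on_train_batch_end("trainer", "module", "outputs", "batch", 0)
timer.on_validation_epoch_start("trainer", "module")
timer.on_validation_epoch_end("trainer", "module")
print(dict(timer.counts))  # {'train_batch': 1, 'val_epoch': 1}
```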
Submodules
- topobench.callbacks.best_epoch_metrics module
  - BestEpochMetricsCallback
  - Callback: Callback.load_state_dict(), Callback.on_after_backward(), Callback.on_before_backward(), Callback.on_before_optimizer_step(), Callback.on_before_zero_grad(), Callback.on_exception(), Callback.on_fit_end(), Callback.on_fit_start(), Callback.on_load_checkpoint(), Callback.on_predict_batch_end(), Callback.on_predict_batch_start(), Callback.on_predict_end(), Callback.on_predict_epoch_end(), Callback.on_predict_epoch_start(), Callback.on_predict_start(), Callback.on_sanity_check_end(), Callback.on_sanity_check_start(), Callback.on_save_checkpoint(), Callback.on_test_batch_end(), Callback.on_test_batch_start(), Callback.on_test_end(), Callback.on_test_epoch_end(), Callback.on_test_epoch_start(), Callback.on_test_start(), Callback.on_train_batch_end(), Callback.on_train_batch_start(), Callback.on_train_end(), Callback.on_train_epoch_end(), Callback.on_train_epoch_start(), Callback.on_train_start(), Callback.on_validation_batch_end(), Callback.on_validation_batch_start(), Callback.on_validation_end(), Callback.on_validation_epoch_end(), Callback.on_validation_epoch_start(), Callback.on_validation_start(), Callback.setup(), Callback.state_dict(), Callback.teardown(), Callback.state_key
  - ModelCheckpoint: ModelCheckpoint.__init__(), ModelCheckpoint.check_monitor_top_k(), ModelCheckpoint.file_exists(), ModelCheckpoint.format_checkpoint_name(), ModelCheckpoint.load_state_dict(), ModelCheckpoint.on_train_batch_end(), ModelCheckpoint.on_train_epoch_end(), ModelCheckpoint.on_train_start(), ModelCheckpoint.on_validation_end(), ModelCheckpoint.setup(), ModelCheckpoint.state_dict(), ModelCheckpoint.to_yaml(), ModelCheckpoint.CHECKPOINT_EQUALS_CHAR, ModelCheckpoint.CHECKPOINT_JOIN_CHAR, ModelCheckpoint.CHECKPOINT_NAME_LAST, ModelCheckpoint.FILE_EXTENSION, ModelCheckpoint.STARTING_VERSION, ModelCheckpoint.every_n_epochs, ModelCheckpoint.state_key
- topobench.callbacks.timer_callback module
  - PipelineTimer: PipelineTimer.__init__(), PipelineTimer.on_test_batch_end(), PipelineTimer.on_test_batch_start(), PipelineTimer.on_test_epoch_end(), PipelineTimer.on_test_epoch_start(), PipelineTimer.on_train_batch_end(), PipelineTimer.on_train_batch_start(), PipelineTimer.on_train_end(), PipelineTimer.on_train_epoch_end(), PipelineTimer.on_train_epoch_start(), PipelineTimer.on_validation_batch_end(), PipelineTimer.on_validation_batch_start(), PipelineTimer.on_validation_epoch_end(), PipelineTimer.on_validation_epoch_start()