topobench.utils.instantiators module#

Instantiators for callbacks and loggers.

class Callback#

Bases: object

Abstract base class used to build new callbacks.

Subclass this class and override any of the relevant hooks.
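
For example, a minimal subclass might override a single hook. A sketch (the callback below is illustrative, not part of this module):

import lightning as L


class PrintingCallback(L.Callback):
    """Toy callback that overrides two hooks."""

    def on_train_start(self, trainer, pl_module):
        # Runs once, just before training starts.
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")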

load_state_dict(state_dict)#

Called when loading a checkpoint; implement to reload callback state given the callback’s state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

on_after_backward(trainer, pl_module)#

Called after loss.backward() and before optimizers are stepped.

on_before_backward(trainer, pl_module, loss)#

Called before loss.backward().

on_before_optimizer_step(trainer, pl_module, optimizer)#

Called before optimizer.step().

on_before_zero_grad(trainer, pl_module, optimizer)#

Called before optimizer.zero_grad().

on_exception(trainer, pl_module, exception)#

Called when any trainer execution is interrupted by an exception.

on_fit_end(trainer, pl_module)#

Called when fit ends.

on_fit_start(trainer, pl_module)#

Called when fit begins.

on_load_checkpoint(trainer, pl_module, checkpoint)#

Called when loading a model checkpoint; use this hook to reload state.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the predict batch ends.

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the predict batch begins.

on_predict_end(trainer, pl_module)#

Called when predict ends.

on_predict_epoch_end(trainer, pl_module)#

Called when the predict epoch ends.

on_predict_epoch_start(trainer, pl_module)#

Called when the predict epoch begins.

on_predict_start(trainer, pl_module)#

Called when predict begins.

on_sanity_check_end(trainer, pl_module)#

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)#

Called when the validation sanity check starts.

on_save_checkpoint(trainer, pl_module, checkpoint)#

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
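
A sketch of using on_save_checkpoint together with on_load_checkpoint to round-trip extra state through the checkpoint dictionary (the callback and the "batches_seen" key are illustrative):

import lightning as L


class CounterCallback(L.Callback):
    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Store extra state in the checkpoint dictionary before it is saved.
        checkpoint["batches_seen"] = self.batches_seen

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Restore that state when the checkpoint is loaded.
        self.batches_seen = checkpoint.get("batches_seen", 0)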

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the test batch ends.

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the test batch begins.

on_test_end(trainer, pl_module)#

Called when the test ends.

on_test_epoch_end(trainer, pl_module)#

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)#

Called when the test epoch begins.

on_test_start(trainer, pl_module)#

Called when the test begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)#

Called when the train batch ends.

Note

The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.
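
A sketch of undoing that normalization inside the hook, assuming training_step returns a loss and that trainer.accumulate_grad_batches holds the configured accumulation factor:

import lightning as L


class RawLossCallback(L.Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if outputs is not None:
            # outputs["loss"] was divided by accumulate_grad_batches;
            # multiply to recover the value returned by training_step.
            raw_loss = outputs["loss"] * trainer.accumulate_grad_batches
            pl_module.log("train/raw_loss", raw_loss)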

on_train_batch_start(trainer, pl_module, batch, batch_idx)#

Called when the train batch begins.

on_train_end(trainer, pl_module)#

Called when training ends.

on_train_epoch_end(trainer, pl_module)#

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_epoch_start(trainer, pl_module)#

Called when the train epoch begins.

on_train_start(trainer, pl_module)#

Called when training begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)#

Called when the validation batch ends.

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)#

Called when the validation batch begins.

on_validation_end(trainer, pl_module)#

Called when the validation loop ends.

on_validation_epoch_end(trainer, pl_module)#

Called when the validation epoch ends.

on_validation_epoch_start(trainer, pl_module)#

Called when the validation epoch begins.

on_validation_start(trainer, pl_module)#

Called when the validation loop begins.

setup(trainer, pl_module, stage)#

Called when fit, validate, test, predict, or tune begins.

state_dict()#

Called when saving a checkpoint; implement to generate the callback’s state_dict.

Returns:

A dictionary containing callback state.

Return type:

Dict[str, Any]

teardown(trainer, pl_module, stage)#

Called when fit, validate, test, predict, or tune ends.

property state_key: str#

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
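
A sketch of a stateful callback that satisfies both conditions, so two instances can be checkpointed side by side (the class and key format are illustrative):

import lightning as L


class RunningCounter(L.Callback):
    def __init__(self, name):
        self.name = name
        self.count = 0

    @property
    def state_key(self):
        # Unique per instance, so each instance gets its own slot in
        # checkpoint["callbacks"][state_key].
        return f"RunningCounter[name={self.name}]"

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.count += 1

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]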

class DictConfig(content, key=None, parent=None, ref_type=typing.Any, key_type=typing.Any, element_type=typing.Any, is_optional=True, flags=None)#

Bases: BaseContainer, MutableMapping[Any, Any]

__init__(content, key=None, parent=None, ref_type=typing.Any, key_type=typing.Any, element_type=typing.Any, is_optional=True, flags=None)#

copy()#

get(key, default_value=None)#

Return the value for key if key is in the dictionary, else default_value (defaulting to None).

items() → a set-like object providing a view on D's items#

items_ex(resolve=True, keys=None)#

keys() → a set-like object providing a view on D's keys#

pop(k[, d]) → v, remove specified key and return the corresponding value.#

If key is not found, d is returned if given, otherwise KeyError is raised.

setdefault(k[, d]) → D.get(k,d), also set D[k]=d if k not in D#
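
A sketch of this mapping interface in use; DictConfig instances are typically built with OmegaConf.create rather than constructed directly:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"lr": 0.01, "epochs": 10})

cfg.get("lr")              # 0.01
cfg.get("missing", 42)     # 42 (the default_value)
cfg.setdefault("seed", 0)  # inserts "seed" and returns 0
cfg.pop("epochs")          # 10, and removes the key
sorted(cfg.keys())         # ["lr", "seed"]
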
class Logger#

Bases: Logger, ABC

Base class for experiment loggers.

after_save_checkpoint(checkpoint_callback)#

Called after model checkpoint callback saves a new checkpoint.

Parameters:

checkpoint_callback (ModelCheckpoint) – the model checkpoint callback instance

property save_dir: str | None#

Return the root directory where experiment logs get saved, or None if the logger does not save data locally.

instantiate_callbacks(callbacks_cfg)#

Instantiate callbacks from config.

Parameters:
callbacks_cfg : DictConfig

A DictConfig object containing callback configurations.

Returns:
list[Callback]

A list of instantiated callbacks.
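
A sketch of calling it with a Hydra-style config, assuming the usual convention that each entry’s _target_ names a callback class (the concrete entries here are illustrative):

from omegaconf import OmegaConf

from topobench.utils.instantiators import instantiate_callbacks

# Illustrative config; real projects load this from YAML via Hydra.
callbacks_cfg = OmegaConf.create(
    {
        "model_checkpoint": {
            "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
            "monitor": "val/loss",
        },
        "early_stopping": {
            "_target_": "lightning.pytorch.callbacks.EarlyStopping",
            "monitor": "val/loss",
            "patience": 5,
        },
    }
)

callbacks = instantiate_callbacks(callbacks_cfg)  # -> list[Callback]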

instantiate_loggers(logger_cfg)#

Instantiate loggers from config.

Parameters:
logger_cfg : DictConfig

A DictConfig object containing logger configurations.

Returns:
list[Logger]

A list of instantiated loggers.
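
Usage mirrors instantiate_callbacks; a sketch with a single CSV logger entry (illustrative, again assuming Hydra-style _target_ keys):

from omegaconf import OmegaConf

from topobench.utils.instantiators import instantiate_loggers

logger_cfg = OmegaConf.create(
    {
        "csv": {
            "_target_": "lightning.pytorch.loggers.CSVLogger",
            "save_dir": "logs/",
        },
    }
)

loggers = instantiate_loggers(logger_cfg)  # -> list[Logger]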