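topobench.model package
TB model module.
- class TBModel(backbone, readout, loss, backbone_wrapper=None, feature_encoder=None, evaluator=None, optimizer=None, **kwargs)
Bases: LightningModule
A LightningModule that assembles a backbone, a readout, and a loss (plus an optional backbone wrapper, feature encoder, evaluator, and optimizer) into a trainable network.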
- Parameters:
- backbone : torch.nn.Module
The backbone model to train.
- readout : torch.nn.Module
The readout class.
- loss : torch.nn.Module
The loss class.
- backbone_wrapper : torch.nn.Module, optional
The backbone wrapper class (default: None).
- feature_encoder : torch.nn.Module, optional
The feature encoder (default: None).
- evaluator : Any, optional
The evaluator class (default: None).
- optimizer : Any, optional
The optimizer class (default: None).
- **kwargs : Any
Additional keyword arguments.
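How the constructor arguments fit together can be sketched without Lightning or PyTorch. The class below is a hypothetical stand-in (TBModelSketch is not a TopoBench name), and it assumes that a backbone_wrapper, when given, is called on the backbone and that the wrapped result is what gets trained; the real TBModel may store its components differently.

```python
from typing import Any, Callable, Optional


class TBModelSketch:
    """Hypothetical, framework-free stand-in for TBModel's constructor.

    Assumption: backbone_wrapper(backbone) produces the module that is trained.
    """

    def __init__(
        self,
        backbone: Callable,
        readout: Callable,
        loss: Callable,
        backbone_wrapper: Optional[Callable[[Callable], Callable]] = None,
        feature_encoder: Optional[Callable] = None,
        evaluator: Any = None,
        optimizer: Any = None,
        **kwargs: Any,
    ) -> None:
        # Wrap the backbone only when a wrapper factory is supplied.
        if backbone_wrapper is not None:
            backbone = backbone_wrapper(backbone)
        self.backbone = backbone
        self.readout = readout
        self.loss = loss
        # Assumption: a missing feature encoder behaves like the identity.
        self.feature_encoder = (
            feature_encoder if feature_encoder is not None else (lambda batch: batch)
        )
        self.evaluator = evaluator
        self.optimizer = optimizer
        # Extra keyword arguments are simply kept around in this sketch.
        self.extra_kwargs = dict(kwargs)


def double(x: float) -> float:
    return 2 * x


def add_one_wrapper(module: Callable) -> Callable:
    # Hypothetical wrapper that post-processes the backbone output.
    return lambda x: module(x) + 1


wrapped = TBModelSketch(double, readout=lambda out: out, loss=lambda out: 0.0,
                        backbone_wrapper=add_one_wrapper)
plain = TBModelSketch(double, readout=lambda out: out, loss=lambda out: 0.0)
print(wrapped.backbone(3), plain.backbone(3))  # 7 6
```

The wrapper is applied once at construction time, so every later call to the backbone goes through it.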
- __init__(backbone, readout, loss, backbone_wrapper=None, feature_encoder=None, evaluator=None, optimizer=None, **kwargs)
- configure_optimizers()
Configure optimizers and learning-rate schedulers.
Choose which optimizers and learning-rate schedulers to use in your optimization. Usually one optimizer is enough, but GANs and similar models may need several.
Examples
https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
- Returns:
- dict:
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
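Lightning accepts a dict from configure_optimizers with an "optimizer" key and an optional "lr_scheduler" sub-dict holding "scheduler", "interval", "frequency", and "monitor" (the last is needed for schedulers such as ReduceLROnPlateau). The helper below only builds that shape; strings stand in for real objects such as a torch.optim.Adam instance, and the function name and the "val/loss" metric name are hypothetical, not taken from TBModel.

```python
from typing import Any, Dict, Optional


def build_optimizer_config(
    optimizer: Any,
    scheduler: Optional[Any] = None,
    monitor: str = "val/loss",  # hypothetical metric name
    interval: str = "epoch",
) -> Dict[str, Any]:
    """Return a dict in the shape Lightning accepts from configure_optimizers."""
    config: Dict[str, Any] = {"optimizer": optimizer}
    if scheduler is not None:
        config["lr_scheduler"] = {
            "scheduler": scheduler,
            "monitor": monitor,    # metric watched by plateau-style schedulers
            "interval": interval,  # step the scheduler every epoch (or "step")
            "frequency": 1,
        }
    return config


opt_only = build_optimizer_config("adam")
with_sched = build_optimizer_config("adam", scheduler="plateau")
print(opt_only)
print(with_sched["lr_scheduler"]["monitor"])
```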
- forward(batch)
Perform a forward pass through the model.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output, which includes the logits and other relevant information.
- log_metrics(mode=None)
Log the current metrics for the given mode.
- Parameters:
- mode : str, optional
The mode whose metrics are logged: “train”, “val”, or “test” (default: None).
- model_step(batch)
Perform a single model step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
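A minimal sketch of mode-aware logging, assuming metrics come from an evaluator with compute() and are logged under a "mode/metric" key. SketchEvaluator, the "accuracy" metric, and the key format are hypothetical; log_fn stands in for LightningModule.log.

```python
from typing import Callable, Dict, Optional


class SketchEvaluator:
    """Hypothetical evaluator that keeps a running mean of one metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total, self.count = 0.0, 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> Dict[str, float]:
        return {"accuracy": self.total / self.count if self.count else 0.0}


def log_metrics(evaluator: SketchEvaluator,
                log_fn: Callable[[str, float], None],
                mode: Optional[str] = None) -> None:
    # Prefix each metric name with the mode, e.g. "val/accuracy" (assumed format).
    for name, value in evaluator.compute().items():
        key = f"{mode}/{name}" if mode else name
        log_fn(key, value)


logged: Dict[str, float] = {}
ev = SketchEvaluator()
ev.update(1.0)
ev.update(0.5)
log_metrics(ev, logged.__setitem__, mode="val")
print(logged)  # {'val/accuracy': 0.75}
```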
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output and the loss.
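The phrase "model step" bundles several stages. The sketch below assumes the order feature encoder, then the forward pass through the backbone, then the readout, then the loss; this order is an assumption, not TBModel's verified code. Plain dicts and lists replace torch_geometric.data.Data and tensors so the example runs without PyTorch.

```python
from typing import Any, Callable, Dict


def model_step(batch: Dict[str, Any], feature_encoder: Callable, backbone: Callable,
               readout: Callable, loss_fn: Callable) -> Dict[str, Any]:
    batch = feature_encoder(batch)                  # encode raw input features
    model_out = backbone(batch)                     # forward pass, returns a dict
    model_out = readout(model_out, batch)           # map hidden states to logits
    model_out["loss"] = loss_fn(model_out, batch)   # attach the loss to the output
    return model_out


def mse(model_out: Dict[str, Any], batch: Dict[str, Any]) -> float:
    pairs = list(zip(model_out["logits"], batch["y"]))
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


out = model_step(
    {"x": [1.0, 2.0], "y": [2.0, 4.0]},
    feature_encoder=lambda b: b,                                  # identity encoder
    backbone=lambda b: {"hidden": [v + 1.0 for v in b["x"]]},     # toy backbone
    readout=lambda mo, b: {"logits": mo["hidden"]},               # toy readout
    loss_fn=mse,
)
print(out)  # {'logits': [2.0, 3.0], 'loss': 0.5}
```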
- on_test_epoch_end()
Lightning hook that is called when a test epoch ends.
This hook is used to log the test metrics.
- on_test_epoch_start()
Lightning hook that is called when a test epoch begins.
This hook is used to reset the test metrics.
- on_train_epoch_end()
Lightning hook that is called when a train epoch ends.
This hook is used to log the train metrics.
- on_train_epoch_start()
Lightning hook that is called when a train epoch begins.
This hook is used to reset the train metrics.
- on_val_epoch_start()
Lightning hook that is called when a validation epoch begins.
This hook is used to reset the validation metrics.
- on_validation_epoch_end()
Lightning hook that is called when a validation epoch ends.
This hook is used to log the validation metrics.
- on_validation_epoch_start()
Lightning hook that is called when a validation epoch begins.
See the PyTorch Lightning hooks documentation: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
Because the validation loop runs inside the training epoch, this hook must log the train metrics before resetting the evaluator for the validation loop.
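The ordering constraint for on_validation_epoch_start can be shown with a framework-free sketch. It assumes, hypothetically, that one evaluator object is shared by the train and validation loops and that validation runs once after the epoch's training batches; HookSketch and RunningMean are illustrative names, not TopoBench classes.

```python
from typing import Dict, List


class RunningMean:
    """Hypothetical shared evaluator tracking the mean of a metric."""

    def __init__(self) -> None:
        self.values: List[float] = []

    def update(self, v: float) -> None:
        self.values.append(v)

    def reset(self) -> None:
        self.values.clear()

    def compute(self) -> float:
        return sum(self.values) / len(self.values) if self.values else float("nan")


class HookSketch:
    def __init__(self) -> None:
        self.evaluator = RunningMean()
        self.logged: Dict[str, float] = {}

    def on_train_epoch_start(self) -> None:
        self.evaluator.reset()

    def training_step(self, v: float) -> None:
        self.evaluator.update(v)

    def on_validation_epoch_start(self) -> None:
        # Log train metrics first: resetting first would discard them.
        self.logged["train"] = self.evaluator.compute()
        self.evaluator.reset()

    def validation_step(self, v: float) -> None:
        self.evaluator.update(v)

    def on_validation_epoch_end(self) -> None:
        self.logged["val"] = self.evaluator.compute()


hooks = HookSketch()
hooks.on_train_epoch_start()
for value in (1.0, 0.0):
    hooks.training_step(value)
hooks.on_validation_epoch_start()
hooks.validation_step(1.0)
hooks.on_validation_epoch_end()
print(hooks.logged)  # {'train': 0.5, 'val': 1.0}
```

Swapping the two lines in on_validation_epoch_start would make the train entry nan, because the training values would already be gone.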
- process_outputs(model_out, batch)
Handle model outputs.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
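The docstring leaves the post-processing unspecified. As one hypothetical example of what such a step can do, the sketch below restricts node-level outputs and labels to the current data split with a boolean mask. The "train_mask"/"val_mask" keys, the "logits"/"labels" output keys, and the split argument are all assumptions for illustration, not TBModel's documented behavior.

```python
from typing import Any, Dict, List


def process_outputs_sketch(model_out: Dict[str, Any], batch: Dict[str, Any],
                           split: str) -> Dict[str, Any]:
    """Hypothetical: keep only entries selected by the split's boolean mask."""
    mask: List[bool] = batch[f"{split}_mask"]  # assumed key, e.g. "train_mask"
    model_out["logits"] = [logit for logit, keep in zip(model_out["logits"], mask) if keep]
    model_out["labels"] = [label for label, keep in zip(batch["y"], mask) if keep]
    return model_out


batch = {"y": [0, 1, 1], "train_mask": [True, False, True], "val_mask": [False, True, False]}
out = process_outputs_sketch({"logits": [0.1, 0.9, 0.8]}, batch, split="train")
print(out)  # {'logits': [0.1, 0.8], 'labels': [0, 1]}
```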
- setup(stage)
Hook to call torch.compile.
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
- stage : str
Either “fit”, “validate”, “test”, or “predict”.
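A common pattern for compiling inside setup is to compile only when stage == "fit". The sketch below follows that pattern under stated assumptions: compile_backbone is a hypothetical flag, and compile_fn is injected (in a real model it would typically be torch.compile) so the example runs without PyTorch. Since setup runs on every process under DDP, each process would compile its own copy.

```python
from typing import Callable, List, Optional


class SetupSketch:
    def __init__(self, backbone: Callable, compile_backbone: bool = False,
                 compile_fn: Optional[Callable[[Callable], Callable]] = None) -> None:
        self.backbone = backbone
        self.compile_backbone = compile_backbone  # hypothetical flag
        self.compile_fn = compile_fn              # e.g. torch.compile in practice

    def setup(self, stage: str) -> None:
        # Compile only for training; other stages keep the eager module.
        if self.compile_backbone and stage == "fit" and self.compile_fn is not None:
            self.backbone = self.compile_fn(self.backbone)


compiled: List[Callable] = []


def fake_compile(module: Callable) -> Callable:
    # Records the call and returns the module unchanged.
    compiled.append(module)
    return module


sketch = SetupSketch(abs, compile_backbone=True, compile_fn=fake_compile)
sketch.setup("test")
calls_after_test = len(compiled)  # no compilation for the "test" stage
sketch.setup("fit")
print(calls_after_test, len(compiled))  # 0 1
```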
- test_step(batch, batch_idx)
Perform a single test step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- training_step(batch, batch_idx)
Perform a single training step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- Returns:
- torch.Tensor
A tensor of losses between model predictions and targets.
- validation_step(batch, batch_idx)
Perform a single validation step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
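The three step methods differ mainly in what they return: only training_step documents a return value, because under Lightning's automatic optimization the returned loss is what gets backpropagated. The sketch assumes all three delegate to a shared model_step and log the loss under a mode prefix; StepSketch, the log stand-in, and the key names are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple


class StepSketch:
    """Hypothetical: three *_step methods delegating to one shared model_step."""

    def __init__(self, model_step: Callable[[Any], Dict[str, Any]]) -> None:
        self.model_step = model_step
        self.logged: List[Tuple[str, float]] = []

    def log(self, name: str, value: float) -> None:
        self.logged.append((name, value))  # stand-in for LightningModule.log

    def training_step(self, batch: Any, batch_idx: int) -> float:
        out = self.model_step(batch)
        self.log("train/loss", out["loss"])
        return out["loss"]  # returned loss drives backpropagation in Lightning

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        out = self.model_step(batch)
        self.log("val/loss", out["loss"])

    def test_step(self, batch: Any, batch_idx: int) -> None:
        out = self.model_step(batch)
        self.log("test/loss", out["loss"])


steps = StepSketch(lambda b: {"loss": float(sum(b))})  # toy model_step
loss = steps.training_step([1, 2], 0)
steps.validation_step([3], 0)
print(loss, steps.logged)
```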
Submodules
- topobench.model.model module
  - Any
  - Data: Data.__init__(), Data.connected_components(), Data.debug(), Data.edge_subgraph(), Data.from_dict(), Data.get_all_edge_attrs(), Data.get_all_tensor_attrs(), Data.is_edge_attr(), Data.is_node_attr(), Data.stores_as(), Data.subgraph(), Data.to_dict(), Data.to_heterogeneous(), Data.to_namedtuple(), Data.update(), Data.validate(), Data.batch, Data.edge_attr, Data.edge_index, Data.edge_stores, Data.edge_weight, Data.face, Data.node_stores, Data.num_edge_features, Data.num_edge_types, Data.num_faces, Data.num_features, Data.num_node_features, Data.num_node_types, Data.num_nodes, Data.pos, Data.stores, Data.time, Data.x, Data.y
  - LightningModule: LightningModule.__init__(), LightningModule.all_gather(), LightningModule.backward(), LightningModule.clip_gradients(), LightningModule.configure_callbacks(), LightningModule.configure_gradient_clipping(), LightningModule.configure_optimizers(), LightningModule.forward(), LightningModule.freeze(), LightningModule.load_from_checkpoint(), LightningModule.log(), LightningModule.log_dict(), LightningModule.lr_scheduler_step(), LightningModule.lr_schedulers(), LightningModule.manual_backward(), LightningModule.optimizer_step(), LightningModule.optimizer_zero_grad(), LightningModule.optimizers(), LightningModule.predict_step(), LightningModule.print(), LightningModule.test_step(), LightningModule.to_onnx(), LightningModule.to_torchscript(), LightningModule.toggle_optimizer(), LightningModule.training_step(), LightningModule.unfreeze(), LightningModule.untoggle_optimizer(), LightningModule.validation_step(), LightningModule.CHECKPOINT_HYPER_PARAMS_KEY, LightningModule.CHECKPOINT_HYPER_PARAMS_NAME, LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE, LightningModule.automatic_optimization, LightningModule.current_epoch, LightningModule.device_mesh, LightningModule.example_input_array, LightningModule.fabric, LightningModule.global_rank, LightningModule.global_step, LightningModule.local_rank, LightningModule.logger, LightningModule.loggers, LightningModule.on_gpu, LightningModule.strict_loading, LightningModule.trainer
  - MeanMetric
  - TBModel: TBModel.__init__(), TBModel.configure_optimizers(), TBModel.forward(), TBModel.log_metrics(), TBModel.model_step(), TBModel.on_test_epoch_end(), TBModel.on_test_epoch_start(), TBModel.on_train_epoch_end(), TBModel.on_train_epoch_start(), TBModel.on_val_epoch_start(), TBModel.on_validation_epoch_end(), TBModel.on_validation_epoch_start(), TBModel.process_outputs(), TBModel.setup(), TBModel.test_step(), TBModel.training_step(), TBModel.validation_step()