topobench.loss package
Subpackages
Submodules
topobench.loss.base module
Abstract base class for the loss classes.
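This module is described only as holding an abstract class for losses. The sketch below shows what such a base class commonly looks like. It assumes the usual abc pattern with a forward(model_out, batch) hook matching the signature documented for TBLoss.forward. The names AbstractLossSketch and ConstantLoss are hypothetical, and this is not topobench's actual source.

```python
from abc import ABC, abstractmethod


class AbstractLossSketch(ABC):
    """Hypothetical stand-in for an abstract loss base class (not topobench's code)."""

    def __call__(self, model_out: dict, batch) -> dict:
        # Calling the loss object delegates to forward(), the same way
        # torch.nn.Module routes __call__ to forward.
        return self.forward(model_out, batch)

    @abstractmethod
    def forward(self, model_out: dict, batch) -> dict:
        """Compute a loss and return model_out with the result attached."""


class ConstantLoss(AbstractLossSketch):
    """Toy concrete subclass that attaches a fixed loss value."""

    def __init__(self, value: float):
        self.value = value

    def forward(self, model_out: dict, batch) -> dict:
        # Store the loss under a "loss" key and hand the same dict back.
        model_out["loss"] = self.value
        return model_out
```

The abstract class itself cannot be instantiated. Only subclasses that implement forward can be used as losses.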
topobench.loss.loss module
Loss module for the topobench package.
- class topobench.loss.loss.TBLoss(dataset_loss, modules_losses={})
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- modules_losses : AbstractLoss, optional
Custom modules’ losses to be used.
- forward(model_out: dict, batch: torch_geometric.data.Data)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
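The TBLoss entry states what forward consumes and returns, but not how the dataset loss and the module losses combine. The following is a minimal, dependency-free sketch of that contract under explicit assumptions:

- dataset_loss carries a "loss_type" key.
- Predictions are read from model_out["logits"] and targets from batch.y.
- modules_losses is a dict of callables whose results are summed with the dataset term.

TBLossSketch, LOSS_FNS, mse and l1 are hypothetical names. Plain lists stand in for tensors, and a SimpleNamespace stands in for a torch_geometric.data.Data batch.

```python
from types import SimpleNamespace


def mse(pred, target):
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred, target):
    """Mean absolute error over two equal-length sequences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Hypothetical lookup from a "loss_type" string to a dataset loss function.
LOSS_FNS = {"mse": mse, "l1": l1}


class TBLossSketch:
    """Hypothetical composite loss following the documented TBLoss contract."""

    def __init__(self, dataset_loss, modules_losses=None):
        # ASSUMPTION: dataset_loss names its loss under a "loss_type" key.
        self.dataset_fn = LOSS_FNS[dataset_loss["loss_type"]]
        # None instead of {} avoids one mutable default shared by all instances.
        self.modules_losses = dict(modules_losses or {})

    def forward(self, model_out, batch):
        # Dataset term: predictions in model_out against targets on the batch.
        total = self.dataset_fn(model_out["logits"], batch.y)
        # ASSUMPTION: each module loss returns a scalar added to the total.
        for module_loss in self.modules_losses.values():
            total += module_loss(model_out, batch)
        # Return the same dictionary, now carrying the loss.
        model_out["loss"] = total
        return model_out

    __call__ = forward


# Usage: MSE of [1, 2] vs [1, 4] is 2.0, plus a constant 0.5 module term.
batch = SimpleNamespace(y=[1.0, 4.0])
loss = TBLossSketch({"loss_type": "mse"}, {"reg": lambda out, b: 0.5})
out = loss({"logits": [1.0, 2.0]}, batch)
print(out["loss"])
```

The sketch defaults modules_losses to None rather than {}. In Python, a {} default is created once and shared by every call, which is harmless only as long as that object is never mutated.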
Module contents
This module implements the loss functions for the topobench package.
- class topobench.loss.TBLoss(dataset_loss, modules_losses={})
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- modules_losses : AbstractLoss, optional
Custom modules’ losses to be used.
- forward(model_out: dict, batch: torch_geometric.data.Data)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.