Loss
This module implements custom Python classes to compute losses in TopoBenchmarkX.
Abstract base class for losses.
class topobenchmark.loss.base.AbstractLoss
Abstract base class for losses.
abstract forward(model_out: dict, batch: Data)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
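A minimal sketch of a custom loss written against this interface. To stay self-contained it defines a stand-in `AbstractLoss` with Python's `abc` module instead of importing `topobenchmark`, and uses a `SimpleNamespace` with a `y` attribute in place of a `torch_geometric.data.Data` batch. The `"logits"` key in `model_out`, the `__call__` dispatch to `forward`, and the `MeanAbsoluteError` class are hypothetical and only illustrate the pattern.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class AbstractLoss(ABC):
    """Stand-in for the documented interface (not the real topobenchmark class)."""

    def __call__(self, model_out: dict, batch):
        # Assumption: calling a loss object dispatches to forward().
        return self.forward(model_out, batch)

    @abstractmethod
    def forward(self, model_out: dict, batch):
        """Compute a loss from the model output and the batch."""


class MeanAbsoluteError(AbstractLoss):
    """Hypothetical custom loss: mean |prediction - target| over the batch."""

    def forward(self, model_out: dict, batch):
        preds = model_out["logits"]  # assumed key holding the predictions
        targets = batch.y            # PyG convention: ground-truth targets live in `y`
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


batch = SimpleNamespace(y=[1.0, 2.0, 3.0])  # stands in for a torch_geometric Data batch
model_out = {"logits": [1.5, 2.0, 2.0]}
loss = MeanAbsoluteError()(model_out, batch)  # (0.5 + 0.0 + 1.0) / 3 == 0.5
```

Because `forward` is declared with `@abstractmethod`, the stand-in base class itself cannot be instantiated; only subclasses that implement `forward` can.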
Loss module for the topobenchmark package.
class topobenchmark.loss.loss.TBLoss(dataset_loss, modules_losses={})
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- modules_losses : AbstractLoss, optional
Custom modules’ losses to be used.
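The exact keys of `dataset_loss` are not listed here. As a hypothetical sketch, assuming the dictionary carries a `"task"` entry and a `"loss_type"` entry, selecting a default loss for the task could be a simple dispatch table. The key names, the `select_default_loss` helper, and the two pure-Python loss functions are illustrative assumptions, not the package's API.

```python
import math


def cross_entropy(logits, labels):
    """Mean negative log-softmax probability of each sample's true class."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def mse(preds, targets):
    """Mean squared error between predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


# Hypothetical (task, loss_type) pairs; the real dataset_loss schema may differ.
DEFAULT_LOSSES = {
    ("classification", "cross_entropy"): cross_entropy,
    ("regression", "mse"): mse,
}


def select_default_loss(dataset_loss: dict):
    """Return the loss function named by a dataset_loss config (hypothetical keys)."""
    key = (dataset_loss["task"], dataset_loss["loss_type"])
    if key not in DEFAULT_LOSSES:
        raise ValueError(f"Unsupported task/loss combination: {key}")
    return DEFAULT_LOSSES[key]


loss_fn = select_default_loss({"task": "regression", "loss_type": "mse"})
value = loss_fn([1.0, 2.0], [1.0, 4.0])  # (0 + 4) / 2 == 2.0
```

Raising on an unknown combination, rather than silently falling back to some default, surfaces configuration mistakes early.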
forward(model_out: dict, batch: Data)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
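One plausible reading of this forward pass: evaluate the default dataset loss and every custom module loss, sum them, and store the total in the model output dictionary before returning it. The sketch below assumes that behavior. The `CombinedLoss` class, the `"loss"` key, and passing an already-built loss callable in place of the `dataset_loss` dict are hypothetical simplifications, and a `SimpleNamespace` stands in for the `Data` batch.

```python
from types import SimpleNamespace


class CombinedLoss:
    """Hypothetical TBLoss-like aggregator (not the real topobenchmark class)."""

    def __init__(self, dataset_loss, modules_losses=None):
        # dataset_loss: a callable (model_out, batch) -> float (assumption).
        # modules_losses: optional dict mapping names to extra loss callables.
        # None replaces a {} default so no dict is shared between instances.
        self.losses = [dataset_loss, *(modules_losses or {}).values()]

    def forward(self, model_out: dict, batch) -> dict:
        # Sum every loss term and attach the total to the model output.
        model_out["loss"] = sum(loss(model_out, batch) for loss in self.losses)
        return model_out


def mean_abs_error(model_out, batch):
    """Hypothetical dataset loss comparing model_out["logits"] with batch.y."""
    preds, targets = model_out["logits"], batch.y
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def constant_penalty(model_out, batch):
    """Hypothetical module loss that adds a fixed regularization term."""
    return 0.5


criterion = CombinedLoss(mean_abs_error, {"penalty": constant_penalty})
out = criterion.forward({"logits": [1.0, 3.0]}, SimpleNamespace(y=[1.0, 1.0]))
# out["loss"] == 1.0 (mean absolute error) + 0.5 (penalty) == 1.5
```

Returning the same dictionary, now extended with the loss, lets downstream code log or backpropagate from a single object that still holds the original predictions.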