Loss

This module implements custom Python classes to compute losses in TopoBenchmarkX.

Abstract base class for losses.

class topobenchmark.loss.base.AbstractLoss

Abstract base class for losses.

abstract forward(model_out: dict, batch: Data)

Forward pass that computes the loss. Subclasses must implement this method.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
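The following is a minimal sketch of a custom loss that follows the `forward(model_out, batch)` contract. To stay self-contained, it does not import `topobenchmark`. Instead it defines a stand-in `AbstractLoss` with Python's `abc` module, uses `types.SimpleNamespace` in place of a `torch_geometric.data.Data` batch, and uses plain lists in place of tensors. Several details are hypothetical: the `"logits"` key in `model_out`, the `y` attribute on the batch, and the `MeanAbsoluteError` class. The documentation above also does not state that `forward` returns a scalar loss, so that is assumed here.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class AbstractLoss(ABC):
    """Stand-in for topobenchmark.loss.base.AbstractLoss (illustration only)."""

    def __call__(self, model_out: dict, batch):
        # Assumed: calling the loss dispatches to forward, as torch.nn.Module does.
        return self.forward(model_out, batch)

    @abstractmethod
    def forward(self, model_out: dict, batch):
        """Forward pass that computes the loss; implemented by subclasses."""


class MeanAbsoluteError(AbstractLoss):
    """Hypothetical custom loss: mean absolute error between predictions and targets."""

    def forward(self, model_out: dict, batch):
        preds = model_out["logits"]  # hypothetical key holding predictions
        targets = batch.y            # hypothetical attribute holding targets
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


loss_fn = MeanAbsoluteError()
batch = SimpleNamespace(y=[1.0, 2.0, 3.0])
print(loss_fn({"logits": [1.5, 2.0, 2.0]}, batch))  # prints 0.5
```

Because `forward` is declared with `@abstractmethod`, instantiating the base class directly raises `TypeError`. Only concrete subclasses can be created.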

Loss module for the topobenchmark package.

class topobenchmark.loss.loss.TBLoss(dataset_loss, modules_losses={})

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.

modules_losses : AbstractLoss, optional

Custom losses defined by the model's modules, used alongside the dataset loss.

forward(model_out: dict, batch: Data)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output, with the computed loss added.
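The sketch below shows one plausible way a `TBLoss`-style wrapper could combine a dataset loss with the modules' losses and return `model_out` with the total attached. It is not the package's implementation, and several of its choices are assumptions:

- Summing the terms.
- Storing the total under a `"loss"` key.
- Skipping `None` entries in `modules_losses`.
- The hypothetical `CompositeLoss` name.

Its `dataset_loss_fn` parameter takes an already-built loss callable, unlike the documented `dataset_loss` dictionary. Plain functions on floats stand in for loss modules and tensors.

```python
from types import SimpleNamespace
from typing import Callable, Optional

# A loss is modelled here as any callable taking (model_out, batch) and returning a float.
LossFn = Callable[[dict, object], float]


class CompositeLoss:
    """Hypothetical TBLoss-style wrapper: dataset loss plus optional module losses."""

    def __init__(self, dataset_loss_fn: LossFn, modules_losses: Optional[dict] = None):
        # The dataset loss always comes first; module losses are appended after it.
        self.losses = [dataset_loss_fn]
        for loss in (modules_losses or {}).values():
            if loss is not None:  # assumption: unused module slots are None
                self.losses.append(loss)

    def forward(self, model_out: dict, batch) -> dict:
        # Sum every loss term and store the total in the output dictionary.
        model_out["loss"] = sum(loss(model_out, batch) for loss in self.losses)
        return model_out


def squared_error(model_out, batch):
    # Hypothetical dataset loss: sum of squared errors.
    return sum((p - t) ** 2 for p, t in zip(model_out["logits"], batch.y))


def l1_penalty(model_out, batch):
    # Hypothetical module loss: penalize large predictions.
    return 0.5 * sum(abs(p) for p in model_out["logits"])


criterion = CompositeLoss(squared_error, {"encoder": l1_penalty, "readout": None})
out = criterion.forward({"logits": [1.0, 2.0]}, SimpleNamespace(y=[0.0, 2.0]))
print(out["loss"])  # prints 2.5
```

The sketch defaults `modules_losses` to `None` rather than `{}`, so separate instances never share one mutable default dictionary.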