topobench.loss.loss module

Loss module for the topobench package.

class AbstractLoss

Bases: ABC

Abstract base class for loss classes.

__init__()
abstract forward(model_out, batch)

Forward pass. Subclasses must implement this method to compute the loss.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
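A concrete loss only needs to implement `forward`. The sketch below uses a local stand-in for `AbstractLoss` so that it runs without TopoBench installed; the subclass `EmbeddingNormPenalty`, its `weight` parameter, and the `"embedding"` key in `model_out` are hypothetical, and plain Python lists stand in for tensors.

```python
from abc import ABC, abstractmethod


class AbstractLoss(ABC):
    """Local stand-in mirroring the documented interface (not imported from topobench)."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, model_out, batch):
        """Return a loss computed from the model output and the batch."""


class EmbeddingNormPenalty(AbstractLoss):
    """Hypothetical loss: weighted sum of squares of an embedding vector."""

    def __init__(self, weight=0.01, key="embedding"):
        super().__init__()
        self.weight = weight
        self.key = key  # hypothetical key assumed to exist in model_out

    def forward(self, model_out, batch):
        values = model_out[self.key]
        return self.weight * sum(v * v for v in values)


penalty = EmbeddingNormPenalty(weight=0.5)
print(penalty.forward({"embedding": [1.0, 2.0]}, batch=None))  # 0.5 * (1 + 4) = 2.5
```

Because `forward` is abstract, instantiating the base class directly raises `TypeError`; only subclasses that implement it can be created.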

class DatasetLoss(dataset_loss)

Bases: AbstractLoss

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.
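The contents of `dataset_loss` are not spelled out here. The sketch below assumes it carries a `"task"` and a `"loss_type"` entry and that this pair selects the criterion later applied by `forward_criterion`. `select_criterion`, `CRITERIA`, and the listed combinations are hypothetical, and the pure-Python functions stand in for PyTorch criteria such as `torch.nn.CrossEntropyLoss`, `torch.nn.MSELoss`, and `torch.nn.L1Loss`.

```python
import math


def mse(preds, targets):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae(preds, targets):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def cross_entropy(logits, targets):
    """Mean cross-entropy of raw class scores against integer class targets."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # shift by the max score for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Hypothetical (task, loss_type) -> criterion table.
CRITERIA = {
    ("classification", "cross_entropy"): cross_entropy,
    ("regression", "mse"): mse,
    ("regression", "mae"): mae,
}


def select_criterion(dataset_loss):
    """Pick a criterion from a dict such as {"task": ..., "loss_type": ...}."""
    key = (dataset_loss["task"], dataset_loss["loss_type"])
    if key not in CRITERIA:
        raise ValueError(f"Unsupported task/loss combination: {key}")
    return CRITERIA[key]


dataset_loss = {"task": "regression", "loss_type": "mse"}
criterion = select_criterion(dataset_loss)
print(criterion([1.0, 3.0], [0.0, 1.0]))  # (1 + 4) / 2 = 2.5
```

Failing fast on an unknown combination surfaces configuration mistakes at construction time rather than during training.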

__init__(dataset_loss)
forward(model_out, batch)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.

forward_criterion(logits, target)

Apply the loss criterion to the model predictions and the ground-truth targets.

Parameters:
logits : torch.Tensor

Model predictions.

target : torch.Tensor

Ground truth labels.

Returns:
torch.Tensor

Loss value.
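If the configured criterion is mean-reduced cross-entropy (an assumption here; mean is the default reduction of `torch.nn.CrossEntropyLoss`), then for a batch of $N$ examples with logits $z_{i,c}$ over $C$ classes and target classes $y_i$, the loss value would be the quantity below. The worked case uses a single example with logits $(2, 0)$ and target class $y_1 = 1$ (the first class).

```latex
\ell(z, y) = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(z_{i,y_i})}{\sum_{c=1}^{C} \exp(z_{i,c})}

% Worked case: N = 1, C = 2, z_1 = (2, 0), y_1 = 1
\ell = -\log \frac{e^{2}}{e^{2} + e^{0}} = \log\left(1 + e^{-2}\right) \approx 0.1269
```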

class TBLoss(dataset_loss, modules_losses={})

Bases: AbstractLoss

Defines the model loss for the given task, combining the dataset loss with any custom module losses.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.

modules_losses : dict of AbstractLoss, optional

Custom losses for individual model modules, used in addition to the dataset loss.

__init__(dataset_loss, modules_losses={})
forward(model_out, batch)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
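`TBLoss` composes the dataset loss with the optional module losses. Assuming the individual losses are summed and the total is written back into `model_out` under a `"loss"` key (both are assumptions, not confirmed by this page), the composition can be sketched in plain Python. `CombinedLoss`, `mse_dataset_loss`, `embedding_penalty`, and the `"logits"`, `"labels"`, and `"embedding"` keys are hypothetical, and lists stand in for tensors.

```python
def mse_dataset_loss(model_out, batch):
    """Hypothetical dataset loss: MSE between predictions and labels."""
    preds, targets = model_out["logits"], model_out["labels"]
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def embedding_penalty(model_out, batch, weight=0.5):
    """Hypothetical module loss: weighted sum of squares of an embedding."""
    return weight * sum(v * v for v in model_out["embedding"])


class CombinedLoss:
    """Hypothetical sketch: dataset loss plus every module loss, summed."""

    def __init__(self, dataset_loss_fn, modules_losses=None):
        # None default gives each instance its own fresh dict.
        self.losses = [dataset_loss_fn, *(modules_losses or {}).values()]

    def forward(self, model_out, batch):
        # Store the total under an assumed "loss" key and return the dict.
        model_out["loss"] = sum(loss(model_out, batch) for loss in self.losses)
        return model_out


combined = CombinedLoss(mse_dataset_loss, {"embedding_penalty": embedding_penalty})
out = combined.forward(
    {"logits": [1.0, 3.0], "labels": [0.0, 1.0], "embedding": [1.0, 2.0]},
    batch=None,
)
print(out["loss"])  # 2.5 (MSE) + 2.5 (penalty) = 5.0
```

Using `None` as the default in the sketch avoids sharing a single mutable dict across instances, a common pitfall when a `{}` literal is used as a default argument.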