topobench.loss.dataset package

This module implements the loss functions for the topobench package.

class DatasetLoss(dataset_loss)

Bases: AbstractLoss

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.

__init__(dataset_loss)
forward(model_out, batch)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
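
The dictionary-in, dictionary-out contract of forward can be sketched as follows. This is a minimal stand-in, not TopoBench's implementation: the class name ToyDatasetLoss, the keys "logits", "labels", and "loss", and the mean-squared-error criterion are all assumptions made for illustration, and plain Python lists stand in for tensors so the sketch runs without dependencies.

```python
class ToyDatasetLoss:
    """Hypothetical stand-in that mimics the forward(model_out, batch) contract."""

    def forward_criterion(self, logits, target):
        # Mean squared error over paired predictions and targets (assumed criterion).
        return sum((p - t) ** 2 for p, t in zip(logits, target)) / len(target)

    def forward(self, model_out, batch):
        # The key names are assumptions; batch is accepted but unused in this sketch.
        model_out["loss"] = self.forward_criterion(
            model_out["logits"], model_out["labels"]
        )
        return model_out


model_out = {"logits": [0.5, 1.5], "labels": [1.0, 1.0]}
result = ToyDatasetLoss().forward(model_out, batch=None)
```

In this sketch the same dictionary that was passed in is returned with one extra entry, which matches the documented return value of "the model output with the loss".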

forward_criterion(logits, target)

Compute the loss value from the model predictions and the ground truth labels.

Parameters:
logits : torch.Tensor

Model predictions.

target : torch.Tensor

Ground truth labels.

Returns:
torch.Tensor

Loss value.
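
As one concrete possibility, a classification criterion maps logits and integer class labels to a scalar loss. The sketch below computes mean cross-entropy in plain Python. Which criterion DatasetLoss actually applies for a given task is not stated on this page, so the choice of cross-entropy, the function name cross_entropy, and the use of nested lists in place of torch.Tensor are assumptions for illustration only.

```python
import math


def cross_entropy(logits, target):
    """Mean cross-entropy for rows of logits and integer class labels."""
    total = 0.0
    for row, label in zip(logits, target):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        # Negative log-probability of the true class.
        total += log_norm - row[label]
    return total / len(target)


# Two samples, two classes; both samples belong to class 0.
loss = cross_entropy([[0.0, 0.0], [2.0, 0.0]], [0, 0])
```

The first sample has uniform logits and contributes log 2, while the second favours the correct class and contributes less, so the mean falls below log 2.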

Submodules