topobench.loss.dataset package

Submodules

topobench.loss.dataset.DatasetLoss module

Loss module for the topobench package.

class topobench.loss.dataset.DatasetLoss.DatasetLoss(dataset_loss)

Bases: AbstractLoss

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.
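The exact keys of `dataset_loss` are not documented here. As a minimal, dependency-free sketch, assume the dictionary carries hypothetical `"task"` and `"loss_type"` entries that select a criterion from a registry. The names `build_criterion` and `CRITERIA`, and the plain-Python MSE/MAE functions standing in for PyTorch criteria, are illustrative only and are not the topobench API.

```python
from typing import Callable, Dict, List, Tuple

Criterion = Callable[[List[float], List[float]], float]


def mse(preds: List[float], target: List[float]) -> float:
    # Mean squared error over paired predictions and targets.
    return sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)


def mae(preds: List[float], target: List[float]) -> float:
    # Mean absolute error over paired predictions and targets.
    return sum(abs(p - t) for p, t in zip(preds, target)) / len(target)


# Hypothetical registry keyed by the assumed ("task", "loss_type") entries.
CRITERIA: Dict[Tuple[str, str], Criterion] = {
    ("regression", "mse"): mse,
    ("regression", "mae"): mae,
}


def build_criterion(dataset_loss: dict) -> Criterion:
    # Hypothetical helper: look up the criterion for the configured task.
    key = (dataset_loss["task"], dataset_loss["loss_type"])
    if key not in CRITERIA:
        raise ValueError(f"unsupported task/loss_type combination: {key}")
    return CRITERIA[key]


criterion = build_criterion({"task": "regression", "loss_type": "mse"})
print(criterion([1.0, 2.0], [1.0, 4.0]))  # 2.0
```

Keeping the selection in a lookup table makes an unsupported combination fail loudly at construction time rather than silently falling back to a default loss.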

forward(model_out: dict, batch: Data)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
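A minimal sketch of how `forward` might delegate to `forward_criterion`, assuming the model output stores predictions under a hypothetical `"logits"` key and targets under a hypothetical `"labels"` key, and that the loss is written back under `"loss"`. `DatasetLossSketch` is a hypothetical plain-Python stand-in, not the topobench class, and the mean-absolute-error criterion is used only to make the example runnable.

```python
from typing import Any, Callable, Dict


class DatasetLossSketch:
    """Hypothetical stand-in for DatasetLoss; not the topobench implementation."""

    def __init__(self, criterion: Callable[[Any, Any], float]) -> None:
        self.criterion = criterion

    def forward_criterion(self, logits: Any, target: Any) -> float:
        # Apply the configured criterion to predictions and ground truth.
        return self.criterion(logits, target)

    def forward(self, model_out: Dict[str, Any], batch: Any = None) -> Dict[str, Any]:
        # `batch` mirrors the documented signature but is unused in this sketch.
        # Assumed keys: predictions in "logits", ground truth in "labels".
        loss = self.forward_criterion(model_out["logits"], model_out["labels"])
        model_out["loss"] = loss  # attach the loss to the output dictionary
        return model_out


def l1(preds, target):
    # Mean absolute error, used only to make the sketch runnable.
    return sum(abs(p - t) for p, t in zip(preds, target)) / len(target)


out = DatasetLossSketch(l1).forward({"logits": [1.0, 3.0], "labels": [2.0, 3.0]})
print(out["loss"])  # 0.5
```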

forward_criterion(logits, target)

Compute the loss value between the model predictions and the ground-truth labels.

Parameters:
logits : torch.Tensor

Model predictions.

target : torch.Tensor

Ground truth labels.

Returns:
torch.Tensor

Loss value.
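For a classification task, one plausible criterion is cross-entropy on raw logits with integer class targets, which is what `torch.nn.CrossEntropyLoss` computes under its default mean reduction. Whether `DatasetLoss` uses exactly this criterion is an assumption here. The plain-Python sketch below (the function name `cross_entropy` is illustrative) uses the log-sum-exp shift for numerical stability.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], target: List[int]) -> float:
    """Mean cross-entropy of integer class labels under softmax(logits).

    Assumed classification criterion; stands in for torch.nn.CrossEntropyLoss.
    """
    total = 0.0
    for row, label in zip(logits, target):
        m = max(row)  # shift by the row max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]  # equals -log(softmax(row)[label])
    return total / len(target)


# Two equal logits: each class has probability 0.5, so the loss is log(2).
print(cross_entropy([[0.0, 0.0]], [0]))
```

Subtracting the row maximum leaves the softmax unchanged but keeps `math.exp` from overflowing when logits are large.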

Module contents

This module implements the loss functions for the topobench package.