topobench.loss.dataset package
Submodules
topobench.loss.dataset.DatasetLoss module
Loss module for the topobench package.
- class topobench.loss.dataset.DatasetLoss.DatasetLoss(dataset_loss)
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- forward(model_out: dict, batch: Data)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output together with the computed loss.
- forward_criterion(logits, target)
Apply the loss criterion to the model predictions and the ground-truth labels.
- Parameters:
- logits : torch.Tensor
Model predictions.
- target : torch.Tensor
Ground truth labels.
- Returns:
- torch.Tensor
Loss value.
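To make the documented contract concrete, here is a minimal, dependency-free sketch of a class with the same three entry points (constructor, `forward`, `forward_criterion`). Everything in it is hypothetical: the `"task"` key in `dataset_loss`, the task-to-criterion mapping, the `"logits"` key in `model_out`, reading targets from `batch.y`, and the `"loss"` output key are assumptions for illustration, not TopoBench's actual implementation, which works on `torch.Tensor` objects. `ToyBatch` is a stand-in for `torch_geometric.data.Data`.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List


def cross_entropy(logits: List[List[float]], target: List[int]) -> float:
    """Mean cross-entropy of raw logits against integer class labels."""
    total = 0.0
    for row, label in zip(logits, target):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_norm - row[label]  # equals -log softmax(row)[label]
    return total / len(target)


def mse(logits: List[float], target: List[float]) -> float:
    """Mean squared error between predictions and regression targets."""
    return sum((p - t) ** 2 for p, t in zip(logits, target)) / len(target)


@dataclass
class ToyBatch:
    """Hypothetical stand-in for torch_geometric.data.Data; only `y` is used."""
    y: list


class ToyDatasetLoss:
    """Hypothetical sketch mirroring the documented DatasetLoss interface.

    The "task" key and its values are assumptions, not TopoBench's schema.
    """

    # Hypothetical mapping from task name to criterion function.
    _CRITERIA: Dict[str, Callable] = {
        "classification": cross_entropy,
        "regression": mse,
    }

    def __init__(self, dataset_loss: dict):
        task = dataset_loss["task"]
        if task not in self._CRITERIA:
            raise ValueError(f"Unsupported task: {task!r}")
        self.criterion = self._CRITERIA[task]

    def forward_criterion(self, logits, target) -> float:
        # Apply the configured criterion to predictions and ground truth.
        return self.criterion(logits, target)

    def forward(self, model_out: dict, batch: ToyBatch) -> dict:
        # Assumes predictions live in model_out["logits"] and targets in batch.y.
        loss = self.forward_criterion(model_out["logits"], batch.y)
        # Return a copy of the model output with the loss stored under "loss".
        return {**model_out, "loss": loss}
```

For example, `ToyDatasetLoss({"task": "classification"}).forward({"logits": [[0.0, 0.0]]}, ToyBatch(y=[0]))["loss"]` evaluates to ln 2 ≈ 0.6931, because uniform logits over two classes give the true class probability 0.5. In this sketch, routing `forward` through `forward_criterion` keeps the choice of criterion in one place, so supporting another task only means registering another function.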
Module contents
This module implements the loss functions for the topobench package.