topobench.loss package
This module implements the loss functions for the topobench package.
- class TBLoss(dataset_loss, modules_losses={})
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- modules_losses : AbstractLoss, optional
Custom modules’ losses to be used.
- __init__(dataset_loss, modules_losses={})
- forward(model_out, batch)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
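The reference above says that TBLoss builds a task loss from `dataset_loss`, optionally combines it with custom module losses, and that `forward` returns the model-output dictionary with the loss attached. The following is a minimal, framework-free sketch of that flow under several assumptions, not the topobench implementation. The class name `SketchLoss`, the config key `"loss_type"`, the dictionary keys `"logits"`, `"y"` and `"loss"`, and the rule that module loss terms are summed into the task loss are all hypothetical. Plain Python lists stand in for torch tensors, and a dict stands in for the `torch_geometric.data.Data` batch.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical signature for a custom module loss term: (model_out, batch) -> float.
LossFn = Callable[[Dict[str, Any], Dict[str, Any]], float]


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error over two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class SketchLoss:
    """Illustrative stand-in for a TBLoss-style wrapper (not topobench code)."""

    # Hypothetical mapping from a "loss_type" config value to a loss function.
    DATASET_LOSSES = {"mse": mse}

    def __init__(self, dataset_loss: Dict[str, Any],
                 modules_losses: Optional[Dict[str, LossFn]] = None) -> None:
        # Select the task loss from the (hypothetical) "loss_type" entry.
        self.dataset_loss_fn = self.DATASET_LOSSES[dataset_loss["loss_type"]]
        # Copy into a fresh dict so instances never share a mutable default.
        self.modules_losses = dict(modules_losses or {})

    def forward(self, model_out: Dict[str, Any],
                batch: Dict[str, Any]) -> Dict[str, Any]:
        # Task loss: compare predictions ("logits") with targets ("y").
        total = self.dataset_loss_fn(model_out["logits"], batch["y"])
        # Assumption: each custom module loss term is added to the task loss.
        for loss_fn in self.modules_losses.values():
            total += loss_fn(model_out, batch)
        # Attach the loss to the model output and return the same dict.
        model_out["loss"] = total
        return model_out

    # Make instances callable like a torch module: loss(model_out, batch).
    __call__ = forward


# Usage: MSE of [1, 2] vs [1, 4] is (0 + 4) / 2 = 2.0, plus a constant 0.5 term.
loss = SketchLoss({"loss_type": "mse"}, {"penalty": lambda out, b: 0.5})
out = loss({"logits": [1.0, 2.0]}, {"y": [1.0, 4.0]})
print(out["loss"])  # 2.5
```

The sketch defaults `modules_losses` to `None` and copies it, rather than using a literal `{}` default as in the documented signature, to sidestep Python's shared-mutable-default pitfall.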