topobench.loss.loss module
Loss module for the topobench package.
- class AbstractLoss
Bases: ABC
Abstract base class for losses.
- __init__()
- abstract forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
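The sketch below shows the `AbstractLoss` contract. It assumes only what this page documents: subclasses implement `forward(model_out, batch)`. The snippet defines a local stand-in for `AbstractLoss` so that it runs without TopoBench installed. The `__call__` delegation, the `MSELoss` subclass, and the `"logits"`/`"labels"` keys are hypothetical. `batch` is a placeholder rather than a real `torch_geometric.data.Data` object.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace

import torch


class AbstractLoss(ABC):
    """Local stand-in mirroring the documented interface (not imported from topobench)."""

    def __init__(self):
        super().__init__()

    def __call__(self, model_out, batch):
        # Assumption: calling the loss delegates to forward().
        return self.forward(model_out, batch)

    @abstractmethod
    def forward(self, model_out, batch):
        """Forward pass."""


class MSELoss(AbstractLoss):
    """Hypothetical concrete loss: mean squared error between two model_out entries."""

    def forward(self, model_out, batch):
        # "logits" and "labels" are assumed key names, not confirmed by this page.
        return torch.mean((model_out["logits"] - model_out["labels"]) ** 2)


try:
    AbstractLoss()
except TypeError:
    print("AbstractLoss cannot be instantiated directly")

loss_fn = MSELoss()
model_out = {"logits": torch.tensor([1.0, 2.0]), "labels": torch.tensor([1.0, 4.0])}
batch = SimpleNamespace()  # placeholder for a torch_geometric.data.Data batch
print(loss_fn(model_out, batch).item())  # 2.0
```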
- class DatasetLoss(dataset_loss)
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
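This page does not document the exact schema of `dataset_loss`. The sketch below is for illustration only. It assumes hypothetical `"task"` and `"loss_type"` keys and maps them to standard PyTorch criteria. The helper `build_criterion` is not part of TopoBench.

```python
import torch


def build_criterion(dataset_loss):
    """Map a dataset_loss dict to a torch criterion.

    The keys "task" and "loss_type" and their accepted values are
    assumptions for illustration, not the documented schema.
    """
    task = dataset_loss["task"]
    loss_type = dataset_loss["loss_type"]
    if task == "classification" and loss_type == "cross_entropy":
        return torch.nn.CrossEntropyLoss()
    if task == "regression" and loss_type == "mse":
        return torch.nn.MSELoss()
    if task == "regression" and loss_type == "mae":
        return torch.nn.L1Loss()
    raise ValueError(f"Unsupported combination: task={task!r}, loss_type={loss_type!r}")


criterion = build_criterion({"task": "regression", "loss_type": "mae"})
print(type(criterion).__name__)  # L1Loss
```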
- __init__(dataset_loss)
- forward(model_out, batch)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
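The sketch below shows one way a `forward` with this signature could produce its documented return value. It reads predictions and labels from `model_out`, delegates to `forward_criterion`, and stores the result back in the dictionary. `SketchDatasetLoss` and the `"logits"`, `"labels"`, and `"loss"` keys are assumptions, not confirmed TopoBench behavior.

```python
from types import SimpleNamespace

import torch


class SketchDatasetLoss:
    """Hypothetical stand-in for DatasetLoss; keys and attributes are assumptions."""

    def __init__(self, criterion):
        self.criterion = criterion

    def forward_criterion(self, logits, target):
        # Apply the criterion to predictions and ground-truth labels.
        return self.criterion(logits, target)

    def forward(self, model_out, batch):
        logits = model_out["logits"]  # assumed key
        target = model_out["labels"]  # assumed key
        # Return the model output dict with the loss attached under an assumed key.
        model_out["loss"] = self.forward_criterion(logits, target)
        return model_out


dataset_loss = SketchDatasetLoss(torch.nn.MSELoss())
out = dataset_loss.forward(
    {"logits": torch.tensor([[0.5], [1.5]]), "labels": torch.tensor([[0.0], [1.0]])},
    SimpleNamespace(),  # placeholder for a torch_geometric.data.Data batch
)
print(out["loss"].item())  # 0.25
```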
- forward_criterion(logits, target)
Apply the loss criterion to the model predictions and the ground-truth labels.
- Parameters:
- logits : torch.Tensor
Model predictions.
- target : torch.Tensor
Ground truth labels.
- Returns:
- torch.Tensor
Loss value.
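For a classification task, `forward_criterion` may reduce to applying a PyTorch criterion to logits and integer class targets. The free function below is a hypothetical illustration that uses `torch.nn.CrossEntropyLoss`. With uniform logits over three classes, the loss equals ln 3 regardless of the target.

```python
import math

import torch

criterion = torch.nn.CrossEntropyLoss()


def forward_criterion(logits, target):
    """Hypothetical free-function version: apply the criterion to predictions and labels."""
    return criterion(logits, target)


logits = torch.zeros(2, 3)  # 2 samples, 3 classes, uniform scores
target = torch.tensor([0, 2])  # ground-truth class indices
loss = forward_criterion(logits, target)
print(round(loss.item(), 4))  # 1.0986
print(math.isclose(loss.item(), math.log(3), rel_tol=1e-6))  # True
```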
- class TBLoss(dataset_loss, modules_losses={})
Bases: AbstractLoss
Defines the default model loss for the given task.
- Parameters:
- dataset_loss : dict
Dictionary containing the dataset loss information.
- modules_losses : AbstractLoss, optional
Custom modules’ losses to be used.
- __init__(dataset_loss, modules_losses={})
- forward(model_out, batch)
Forward pass of the loss function.
- Parameters:
- model_outdict
Dictionary containing the model output.
- batchtorch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
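`TBLoss` accepts `modules_losses` alongside the dataset loss. The sketch below makes three illustrative assumptions. First, `modules_losses` is a dict of named loss callables, which matches the `{}` default. Second, `forward` sums every term. Third, the sum is stored in `model_out["loss"]`. `SketchTBLoss`, `mse_term`, and `l2_penalty` are hypothetical names. The sketch takes `None` as its default because Python evaluates a `{}` default once and shares that dict across calls.

```python
from types import SimpleNamespace

import torch


class SketchTBLoss:
    """Hypothetical stand-in for TBLoss.

    Assumptions: the dataset loss and every module loss are callables
    returning a scalar tensor, and the total loss is their sum.
    """

    def __init__(self, dataset_loss_fn, modules_losses=None):
        # None instead of a mutable {} default avoids sharing one dict across instances.
        modules_losses = modules_losses or {}
        self.losses = [dataset_loss_fn, *modules_losses.values()]

    def forward(self, model_out, batch):
        terms = [loss(model_out, batch) for loss in self.losses]
        model_out["loss"] = torch.stack(terms).sum()
        return model_out


def mse_term(model_out, batch):
    return torch.nn.functional.mse_loss(model_out["logits"], model_out["labels"])


def l2_penalty(model_out, batch):
    # Hypothetical module loss: small penalty on the prediction magnitude.
    return 0.1 * model_out["logits"].pow(2).mean()


tb_loss = SketchTBLoss(mse_term, {"l2": l2_penalty})
out = tb_loss.forward(
    {"logits": torch.tensor([1.0, 3.0]), "labels": torch.tensor([1.0, 1.0])},
    SimpleNamespace(),  # placeholder for a torch_geometric.data.Data batch
)
print(out["loss"].item())  # 2.5
```

With these inputs, the MSE term is 2.0 and the penalty is 0.1 × 5.0 = 0.5, so the total is 2.5.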