Source code for loss.loss
"""Loss module for the topobench package."""
import torch
import torch_geometric
from topobench.loss.base import AbstractLoss
from topobench.loss.dataset import DatasetLoss
[docs]
class TBLoss(AbstractLoss):
r"""Defines the default model loss for the given task.
Parameters
----------
dataset_loss : dict
Dictionary containing the dataset loss information.
modules_losses : AbstractLoss, optional
Custom modules' losses to be used.
"""
def __init__(self, dataset_loss, modules_losses={}): # noqa: B006
super().__init__()
self.losses = []
# Dataset loss
self.losses.append(DatasetLoss(dataset_loss))
# Model losses
self.losses.extend(
[loss for loss in modules_losses.values() if loss is not None]
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})"
[docs]
def forward(self, model_out: dict, batch: torch_geometric.data.Data):
r"""Forward pass of the loss function.
Parameters
----------
model_out : dict
Dictionary containing the model output.
batch : torch_geometric.data.Data
Batch object containing the batched domain data.
Returns
-------
dict
Dictionary containing the model output with the loss.
"""
losses = [loss(model_out, batch) for loss in self.losses]
model_out["loss"] = torch.stack(losses).sum()
return model_out