Source code for topobench.loss.dataset.DatasetLoss
"""Loss module for the topobench package."""
import torch
import torch_geometric
from topobench.loss.base import AbstractLoss
[docs]
class DatasetLoss(AbstractLoss):
    r"""Defines the default model loss for the given task.

    The criterion is chosen from the ``"task"`` and ``"loss_type"`` entries
    of ``dataset_loss``:

    * ``"classification"`` / ``"cross_entropy"``: ``torch.nn.CrossEntropyLoss``.
    * ``"multilabel classification"`` / ``"BCE"``:
      ``torch.nn.BCEWithLogitsLoss(reduction="none")``; the reduction is done
      in :meth:`forward_criterion` so that NaN (missing) labels can be
      masked out.
    * ``"regression"`` / ``"mse"``: ``torch.nn.MSELoss``.
    * ``"regression"`` / ``"mae"``: ``torch.nn.L1Loss``.

    Parameters
    ----------
    dataset_loss : dict
        Dictionary containing the dataset loss information. The keys
        ``"task"`` and ``"loss_type"`` are read.

    Raises
    ------
    AssertionError
        If the task is a (multilabel) classification task and the loss type
        is not the one supported for it.
    Exception
        If no loss is defined for the task / loss type combination.
    """

    def __init__(self, dataset_loss):
        super().__init__()
        self.task = dataset_loss["task"]
        self.loss_type = dataset_loss["loss_type"]

        # Dataset loss
        if self.task == "classification":
            assert self.loss_type == "cross_entropy", (
                "Invalid loss type for classification task,TB supports only cross_entropy loss for classification task"
            )
            self.criterion = torch.nn.CrossEntropyLoss()
        elif self.task == "multilabel classification":
            assert self.loss_type == "BCE", (
                "Invalid loss type for classification task,TB supports only BCE for multilabel classification task"
            )
            # Unreduced on purpose: entries with missing labels are masked
            # out before averaging in ``forward_criterion``.
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
        elif self.task == "regression" and self.loss_type == "mse":
            self.criterion = torch.nn.MSELoss()
        elif self.task == "regression" and self.loss_type == "mae":
            self.criterion = torch.nn.L1Loss()
        else:
            raise Exception("Loss is not defined")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})"

    def forward(self, model_out: dict, batch: torch_geometric.data.Data):
        r"""Forward pass of the loss function.

        Parameters
        ----------
        model_out : dict
            Dictionary containing the model output. The entries ``"logits"``
            and ``"labels"`` are read.
        batch : torch_geometric.data.Data
            Batch object containing the batched domain data. It is not read
            by this method; the labels are taken from ``model_out``.

        Returns
        -------
        torch.Tensor
            Loss value computed by :meth:`forward_criterion`.
        """
        logits = model_out["logits"]
        target = model_out["labels"]
        return self.forward_criterion(logits, target)

    def forward_criterion(self, logits, target):
        r"""Apply the task-specific criterion to predictions and labels.

        Parameters
        ----------
        logits : torch.Tensor
            Model predictions.
        target : torch.Tensor
            Ground truth labels. For the multilabel classification task, NaN
            entries are treated as missing labels and do not contribute to
            the loss.

        Returns
        -------
        torch.Tensor
            Loss value.

        Raises
        ------
        Exception
            If no loss is defined for ``self.task``.
        """
        if self.task == "regression":
            # NOTE(review): presumably targets are 1-D and logits are (N, 1),
            # hence the added trailing axis -- confirm against callers.
            target = target.unsqueeze(1)
            dataset_loss = self.criterion(logits, target)

        elif self.task == "classification":
            dataset_loss = self.criterion(logits, target)

        elif self.task == "multilabel classification":
            mask = ~torch.isnan(target)
            # Avoid NaN values in the target
            target = torch.where(mask, target, torch.zeros_like(target))
            loss = self.criterion(logits, target)
            # Mask out the loss for NaN values
            loss = loss * mask
            # Average over the valid labels of each sample. A sample without
            # any valid label would otherwise yield 0 / 0 = NaN and poison
            # the whole batch loss, so the divisor is clamped to 1 (the
            # masked numerator is already 0 for such a sample).
            valid_counts = mask.sum(dim=-1)
            per_sample = loss.sum(dim=-1) / valid_counts.clamp(min=1)
            # Average over the samples that carry at least one valid label;
            # identical to ``.mean()`` when every sample is labelled, and 0
            # when none is.
            num_labeled = (valid_counts > 0).sum().clamp(min=1)
            dataset_loss = per_sample.sum() / num_labeled

        else:
            raise Exception("Loss is not defined")

        return dataset_loss