topobench.loss package

Submodules

topobench.loss.base module

Abstract base class for loss classes.

class topobench.loss.base.AbstractLoss

Bases: ABC

Abstract base class that loss classes inherit from.

abstract forward(model_out: dict, batch: Data)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
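
As a rough illustration of the contract described above, the following sketch defines a local stand-in for the abstract class and one concrete subclass. `AbstractLossSketch`, `MeanAbsoluteErrorSketch`, the float return value, and the `"logits"`/`"labels"` keys are assumptions made for this example, not names taken from topobench. Plain Python lists replace tensors and `batch` is passed as `None` so that the snippet runs without any dependencies.

```python
from abc import ABC, abstractmethod


class AbstractLossSketch(ABC):
    """Local stand-in mirroring the documented interface (not the topobench class)."""

    @abstractmethod
    def forward(self, model_out: dict, batch) -> float:
        """Compute a loss from the model output and the batch."""

    def __call__(self, model_out: dict, batch) -> float:
        # Convenience so the loss object can be called like a function.
        return self.forward(model_out, batch)


class MeanAbsoluteErrorSketch(AbstractLossSketch):
    """Hypothetical concrete loss: mean absolute error over plain lists."""

    def forward(self, model_out: dict, batch) -> float:
        # The "logits" and "labels" keys are assumptions made for this sketch.
        preds, targets = model_out["logits"], model_out["labels"]
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def can_instantiate(cls) -> bool:
    """Return True if cls can be instantiated without arguments."""
    try:
        cls()
    except TypeError:
        # Raised when an abstract method has not been overridden.
        return False
    return True


mae = MeanAbsoluteErrorSketch()
value = mae({"logits": [1.0, 2.0, 4.0], "labels": [1.0, 3.0, 2.0]}, batch=None)
```

Because `forward` is abstract, the base class itself cannot be instantiated; only subclasses that override `forward` can.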

topobench.loss.loss module

Loss module for the topobench package.

class topobench.loss.loss.TBLoss(dataset_loss, modules_losses={})

Bases: AbstractLoss

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.

modules_losses : AbstractLoss, optional

Custom modules’ losses to be used (default: {}).

forward(model_out: dict, batch: Data)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output together with the computed loss.
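
The return convention above (a dictionary that carries the model output plus the loss) can be sketched with a small composite loss. This is a simplified stand-in under several assumptions, not the TBLoss implementation: `CompositeLossSketch` and `ConstantLoss` are hypothetical names, the component losses are summed, the result is stored under a `"loss"` key, `modules_losses` is treated as a mapping from names to loss objects, and `dataset_loss` is passed as a ready-made loss object rather than the configuration dictionary the documented signature expects.

```python
class ConstantLoss:
    """Hypothetical component loss that ignores its inputs."""

    def __init__(self, value: float):
        self.value = value

    def forward(self, model_out: dict, batch) -> float:
        return self.value


class CompositeLossSketch:
    """Stand-in that combines a dataset loss with optional module losses."""

    def __init__(self, dataset_loss, modules_losses=None):
        # None is used here instead of a mutable {} default argument.
        modules_losses = modules_losses or {}
        self.losses = [dataset_loss]
        # Assumption: entries set to None are skipped.
        self.losses.extend(
            loss for loss in modules_losses.values() if loss is not None
        )

    def forward(self, model_out: dict, batch) -> dict:
        # Assumption: the total is the plain sum, stored under "loss".
        model_out["loss"] = sum(
            loss.forward(model_out, batch) for loss in self.losses
        )
        return model_out


criterion = CompositeLossSketch(
    ConstantLoss(0.5),
    {"readout": ConstantLoss(0.25), "unused": None},
)
out = criterion.forward({"logits": [0.1]}, batch=None)
```

In this sketch the returned dictionary keeps every key of the original model output and gains one extra entry for the combined loss.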

Module contents

This module implements the loss functions for the topobench package.

class topobench.loss.TBLoss(dataset_loss, modules_losses={})

Bases: AbstractLoss

Defines the default model loss for the given task.

Parameters:
dataset_loss : dict

Dictionary containing the dataset loss information.

modules_losses : AbstractLoss, optional

Custom modules’ losses to be used (default: {}).

forward(model_out: dict, batch: Data)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output together with the computed loss.