topobench.loss.model.DGMLoss module

Differentiable Graph Module loss function.

class AbstractLoss

Bases: ABC

Abstract base class for loss functions.

__init__()
abstract forward(model_out, batch)

Forward pass. Subclasses must implement this method to compute the loss.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
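As a minimal sketch of how this abstract interface is meant to be used, the block below re-creates it with a stand-in base class (`AbstractLossStandIn`, not the real `AbstractLoss`) and a hypothetical subclass. The `"preds"` and `"loss"` dictionary keys are assumptions chosen for illustration, and a `SimpleNamespace` stands in for a `torch_geometric.data.Data` batch, reading targets from its `y` attribute.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class AbstractLossStandIn(ABC):
    """Stand-in mirroring the documented AbstractLoss interface (not the real class)."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, model_out, batch):
        """Compute a loss from the model output dictionary and the batch."""


class MeanAbsErrorLoss(AbstractLossStandIn):
    """Hypothetical concrete loss, used only to illustrate subclassing."""

    def forward(self, model_out, batch):
        # "preds" is a hypothetical key; batch.y holds the targets, as on a Data object.
        preds, targets = model_out["preds"], batch.y
        # Mean absolute error, stored under a hypothetical "loss" key.
        model_out["loss"] = sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)
        # Return the model output dictionary, now carrying the loss.
        return model_out


# Usage: a SimpleNamespace stands in for a batched Data object.
batch = SimpleNamespace(y=[1.5, 2.0])
out = MeanAbsErrorLoss().forward({"preds": [1.0, 2.0]}, batch)
print(out["loss"])  # 0.25
```

Because `forward` is declared with `@abstractmethod`, instantiating the base class directly raises `TypeError`; only subclasses that implement `forward` can be constructed.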

class DGMLoss(loss_weight=0.5)

Bases: AbstractLoss

DGM loss function.

Original implementation: lcosmo/DGM_pytorch.

Parameters:
loss_weight : float, optional

Weight of the DGM loss term (default: 0.5).

__init__(loss_weight=0.5)
forward(model_out, batch)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
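The description above does not spell out the loss itself. The sketch below is a plain-Python illustration of the reward-based graph loss described for the original DGM, under stated assumptions: each node's sampled-edge log-probabilities are weighted by the gap between its running average accuracy and its current correctness, so edges that led to correct predictions are reinforced. It is not TopoBench's implementation. The class name `DGMGraphLossSketch`, the argument layout, the initial running accuracy of 0.5, the momentum of 0.95, averaging over nodes, and the use of `loss_weight` as a plain multiplier are all assumptions. Unlike the documented `forward`, it returns a scalar rather than the model output dictionary.

```python
from dataclasses import dataclass, field


@dataclass
class DGMGraphLossSketch:
    """Hypothetical, stand-alone sketch of a DGM-style graph loss (not TopoBench's code)."""

    loss_weight: float = 0.5  # assumed to scale the graph loss
    momentum: float = 0.95    # assumed decay of the running per-node accuracy
    avg_accuracy: list = field(default_factory=list)

    def forward(self, preds, labels, edge_logprobs):
        # preds, labels: per-node predicted and true class indices.
        # edge_logprobs: per node, the log-probabilities of the edges sampled for it.
        correct = [1.0 if p == y else 0.0 for p, y in zip(preds, labels)]
        if not correct:
            return 0.0
        if not self.avg_accuracy:
            # Start every node's running accuracy at 0.5 (assumption).
            self.avg_accuracy = [0.5] * len(correct)
        loss = 0.0
        for avg, acc, logps in zip(self.avg_accuracy, correct, edge_logprobs):
            # Reward gap: positive when the node did worse than its running average,
            # so minimizing the loss lowers the log-probability of its sampled edges;
            # negative when it did better, so those edges are made more likely.
            delta = avg - acc
            loss += delta * sum(logps)
        loss /= len(correct)  # average over nodes (assumption)
        # Update the running accuracy after the loss has been computed.
        self.avg_accuracy = [
            self.momentum * avg + (1.0 - self.momentum) * acc
            for avg, acc in zip(self.avg_accuracy, correct)
        ]
        return self.loss_weight * loss


# Usage: three nodes, the last one misclassified.
sketch = DGMGraphLossSketch()
loss = sketch.forward([0, 1, 1], [0, 1, 0], [[-0.1, -0.2], [-0.3], [-0.5, -0.5]])
print(round(loss, 4))  # -0.0333
```

In a tensor implementation, the reward gap comes from argmax predictions and carries no gradient, so only the edge log-probabilities would be updated by this term, which is what makes it a REINFORCE-style estimator.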