topobench.loss.model.DGMLoss module
Differentiable Graph Module loss function.
- class AbstractLoss
Bases: ABC
Abstract base class for loss functions.
- __init__()
- abstract forward(model_out, batch)
Forward pass. Subclasses must override this method.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
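A minimal sketch of the contract that AbstractLoss documents. Because the class derives from ABC and declares forward as abstract, it cannot be instantiated directly, and a concrete loss only has to implement forward(model_out, batch) and return a dictionary. The AbstractLoss below is a local stand-in that mirrors the documented signature, not an import from topobench; ConstantPenaltyLoss, its penalty parameter, and the "loss" key are hypothetical and chosen only for illustration.

```python
from abc import ABC, abstractmethod


class AbstractLoss(ABC):
    """Local stand-in mirroring the documented interface (not topobench's class)."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, model_out, batch):
        """Return the model-output dict with the loss filled in."""


class ConstantPenaltyLoss(AbstractLoss):
    """Hypothetical subclass: adds a fixed penalty to a hypothetical 'loss' entry."""

    def __init__(self, penalty=1.0):
        super().__init__()
        self.penalty = penalty

    def forward(self, model_out, batch):
        # Copy so the caller's dictionary is not mutated in place.
        out = dict(model_out)
        out["loss"] = out.get("loss", 0.0) + self.penalty
        return out


try:
    AbstractLoss()  # a class with unimplemented abstract methods cannot be instantiated
except TypeError:
    print("AbstractLoss cannot be instantiated directly")

loss_fn = ConstantPenaltyLoss(penalty=0.25)
print(loss_fn.forward({"loss": 1.0}, batch=None))  # {'loss': 1.25}
```

Returning a copy rather than editing model_out in place keeps a loss object free of side effects on its inputs, which makes it safe to chain several losses over the same output dictionary.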
- class DGMLoss(loss_weight=0.5)
Bases: AbstractLoss
Loss function for the Differentiable Graph Module (DGM).
Original implementation: lcosmo/DGM_pytorch
- Parameters:
- loss_weight : float, optional
Weight applied to the DGM loss term (default: 0.5).
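For orientation, the reward-based graph-learning objective behind DGM can be written as below. This is a sketch of the idea, not a transcription of DGMLoss: the symbols, the per-node mean, and especially the way loss_weight (written $\lambda$) combines the graph term with the task loss are assumptions. Here $a_i \in \{0, 1\}$ marks whether node $i$ is classified correctly, $\bar{a}_i$ is a running average of $a_i$ used as a baseline, $N$ is the number of nodes, and $p_{ij}$ is the probability of the sampled edge $(i, j)$ in the sampled edge set $\mathcal{E}$.

```latex
\delta_i = \bar{a}_i - a_i, \qquad
\mathcal{L}_{\text{graph}} = \frac{1}{N} \sum_{i=1}^{N} \delta_i
  \sum_{j:\,(i,j) \in \mathcal{E}} \log p_{ij}, \qquad
\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda \, \mathcal{L}_{\text{graph}},
\quad \lambda = \texttt{loss\_weight}
```

For a misclassified node ($a_i = 0$) the weight $\delta_i = \bar{a}_i \ge 0$, so minimising $\mathcal{L}$ lowers the probability of that node's sampled edges; for a correctly classified node $\delta_i = \bar{a}_i - 1 \le 0$, so minimising $\mathcal{L}$ raises it.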
- __init__(loss_weight=0.5)
- forward(model_out, batch)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output, with the computed loss included.
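The forward pass can be sketched in framework-free Python, assuming it follows the reward-based scheme of the original DGM implementation: each node's prediction is scored as correct or not, compared with a running-average accuracy baseline, and that difference weights the log-probabilities of the edges sampled for the node. Everything here is hypothetical: DGMLossSketch, the model_out keys "loss", "preds" and "edge_logprobs", the plain dict with key "y" standing in for torch_geometric.data.Data, the 0.5 initial baseline, the 0.95 momentum, and the assumption that loss_weight scales the graph term before it is added to the existing loss. The real class operates on tensors.

```python
import math


def dgm_graph_loss(correct, avg_accuracy, edge_logprobs):
    """Mean over nodes of (baseline - correctness) * sum of sampled-edge log-probs.

    correct[i]: 1.0 if node i was classified correctly, else 0.0
    avg_accuracy[i]: running-average correctness of node i (the baseline)
    edge_logprobs[i]: log-probabilities of the edges sampled for node i
    """
    total = 0.0
    for a_i, baseline, logps in zip(correct, avg_accuracy, edge_logprobs):
        # delta > 0 for a wrong node: minimising the loss lowers the probability
        # of its sampled edges; delta < 0 for a correct node raises it.
        delta = baseline - a_i
        total += delta * sum(logps)
    return total / len(correct)


def update_avg_accuracy(avg_accuracy, correct, momentum=0.95):
    # Exponential moving average of per-node correctness (momentum is illustrative).
    return [momentum * b + (1.0 - momentum) * a for b, a in zip(avg_accuracy, correct)]


class DGMLossSketch:
    """Hypothetical stand-in for DGMLoss; all dict keys are assumptions."""

    def __init__(self, loss_weight=0.5):
        self.loss_weight = loss_weight
        self.avg_accuracy = None  # per-node baseline, created on first call

    def forward(self, model_out, batch):
        correct = [1.0 if p == y else 0.0 for p, y in zip(model_out["preds"], batch["y"])]
        if self.avg_accuracy is None:
            self.avg_accuracy = [0.5] * len(correct)  # illustrative initial baseline
        # Score against the current baseline, then update it for the next call.
        graph_loss = dgm_graph_loss(correct, self.avg_accuracy, model_out["edge_logprobs"])
        self.avg_accuracy = update_avg_accuracy(self.avg_accuracy, correct)
        out = dict(model_out)  # leave the caller's dict untouched
        # Assumption: loss_weight scales the graph term added to the task loss.
        out["loss"] = model_out["loss"] + self.loss_weight * graph_loss
        return out


loss_fn = DGMLossSketch(loss_weight=0.5)
model_out = {
    "loss": 1.0,      # existing task loss
    "preds": [0, 1],  # node 0 correct, node 1 wrong
    "edge_logprobs": [[math.log(0.5)], [math.log(0.25)]],
}
out = loss_fn.forward(model_out, {"y": [0, 0]})
print(round(out["loss"], 4))  # 0.9134
```

The baseline is kept on the loss object because it must persist across calls to act as a running average; everything else is recomputed per batch, and model_out is copied rather than mutated.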