topobench.loss.model package
Submodules
topobench.loss.model.DGMLoss module
Differentiable Graph Module loss function.
- class topobench.loss.model.DGMLoss.DGMLoss(loss_weight=0.5)
Bases: AbstractLoss
DGM loss function.
Original implementation: lcosmo/DGM_pytorch
- Parameters:
- loss_weight : float, optional
Loss weight (default: 0.5).
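The DGM paper (Kazi et al.) trains the graph-sampling module with a loss that rewards edges sampled by correctly classified nodes and penalises edges sampled by misclassified ones, weighting each node by the gap between a running accuracy estimate and its current correctness. A minimal sketch of that idea follows. The function names, tensor layouts, the 0.95 momentum, and the assumption that loss_weight simply multiplies the graph loss are all hypothetical and are not taken from TopoBench's DGMLoss.

```python
import torch


def dgm_graph_loss(
    edge_logprobs: torch.Tensor,
    correct: torch.Tensor,
    avg_accuracy: torch.Tensor,
    loss_weight: float = 0.5,
) -> torch.Tensor:
    """Hypothetical sketch of the DGM graph-learning loss.

    edge_logprobs: (num_nodes, k) log-probabilities of the k edges sampled per node.
    correct:       (num_nodes,) 1.0 where the node was classified correctly, else 0.0.
    avg_accuracy:  (num_nodes,) running estimate of each node's accuracy.
    """
    # delta_i = E[acc_i] - acc_i is negative for correctly classified nodes
    # (minimising the loss then raises their edges' log-probabilities) and
    # positive for misclassified nodes (their edges' log-probabilities drop).
    delta = (avg_accuracy - correct).detach()
    graph_loss = (delta * edge_logprobs.sum(dim=-1)).mean()
    # Assumption: loss_weight simply scales the graph loss.
    return loss_weight * graph_loss


def update_avg_accuracy(
    avg_accuracy: torch.Tensor, correct: torch.Tensor, momentum: float = 0.95
) -> torch.Tensor:
    """Hypothetical exponential moving average of per-node correctness."""
    return momentum * avg_accuracy + (1.0 - momentum) * correct


if __name__ == "__main__":
    # Two nodes, two sampled edges each, every edge with probability 0.5.
    logp = torch.log(torch.full((2, 2), 0.5)).requires_grad_()
    correct = torch.tensor([1.0, 0.0])  # node 0 right, node 1 wrong
    avg_acc = torch.full((2,), 0.5)
    loss = dgm_graph_loss(logp, correct, avg_acc)
    loss.backward()
    # Negative gradient for node 0: gradient descent increases its edge log-probs.
    print(loss.item(), logp.grad)
    print(update_avg_accuracy(avg_acc, correct))
```

Detaching delta keeps the per-node weight a constant, so gradients flow only into the edge log-probabilities, which is what makes the loss act as a reward signal for the sampler rather than for the classifier.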
topobench.loss.model.GraphMLPLoss module
Graph MLP loss function.
- class topobench.loss.model.GraphMLPLoss.GraphMLPLoss(r_adj_power=2, tau=1.0, loss_weight=0.5)
Bases: AbstractLoss
Graph MLP loss function.
- Parameters:
- r_adj_power : int, optional
Power of the adjacency matrix (default: 2).
- tau : float, optional
Temperature parameter (default: 1.0).
- loss_weight : float, optional
Loss weight (default: 0.5).
- forward(model_out: dict, batch: Data) → Tensor
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
- get_power_adj(edge_index)
Compute the r_adj_power-th power of the adjacency matrix built from edge_index.
- Parameters:
- edge_index : torch.Tensor
Edge index tensor of shape [2, num_edges] (COO format).
- Returns:
- torch.Tensor
Power of the adjacency matrix.
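What an adjacency power computes can be illustrated with a minimal sketch, assuming a dense matrix built from a COO edge_index and an optional self-loop step. The power_adj function and its add_self_loops flag are hypothetical; whether get_power_adj adds self-loops or works with sparse tensors is not stated in this documentation. Self-loops matter because Graph-MLP treats nodes within r hops as positives, which (A + I)^r captures, whereas A^r alone only counts walks of length exactly r.

```python
import torch


def power_adj(
    edge_index: torch.Tensor, num_nodes: int, r: int = 2, add_self_loops: bool = True
) -> torch.Tensor:
    """Hypothetical dense sketch of an r-th adjacency power from a COO edge_index."""
    adj = torch.zeros(num_nodes, num_nodes)
    # edge_index[0] holds source nodes, edge_index[1] target nodes.
    adj[edge_index[0], edge_index[1]] = 1.0
    if add_self_loops:
        # With self-loops, (A + I)^r is non-zero exactly for node pairs at most
        # r hops apart; without them, A^r only counts walks of length exactly r.
        adj = adj + torch.eye(num_nodes)
    return torch.linalg.matrix_power(adj, r)


if __name__ == "__main__":
    # Undirected path graph 0 - 1 - 2 - 3, each edge stored in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    with_loops = power_adj(edge_index, num_nodes=4, r=2)
    without_loops = power_adj(edge_index, num_nodes=4, r=2, add_self_loops=False)
    print(with_loops)     # row 0 is non-zero for nodes 0, 1, 2
    print(without_loops)  # row 0 is non-zero only for nodes 0 and 2
```

A dense matrix power is fine for small graphs; for large graphs the O(n^2) memory of the dense result is the limiting factor, which is why a sparse representation may be preferable in practice.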
- graph_mlp_contrast_loss(x_dis, adj_label)
Graph MLP contrastive loss.
- Parameters:
- x_dis : torch.Tensor
Distance matrix.
- adj_label : torch.Tensor
Adjacency matrix.
- Returns:
- torch.Tensor
Contrastive loss.
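The contrastive loss can be illustrated with a minimal sketch of the neighbourhood-contrastive objective from the Graph-MLP paper: for each node, the exponentiated similarity mass on its positive pairs is divided by the mass over all nodes, and the negative log of that ratio is averaged. The sketch assumes that x_dis holds pairwise cosine similarities between node embeddings with a zeroed diagonal (built here by a hypothetical feature_similarity helper), that adj_label marks positive pairs, and that tau multiplies the similarities. None of these details are confirmed by the documentation above, and the function names are illustrative rather than TopoBench's API.

```python
import torch
import torch.nn.functional as F


def feature_similarity(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pairwise cosine similarity with the diagonal zeroed."""
    x_norm = F.normalize(x, dim=1)
    sim = x_norm @ x_norm.T
    return sim * (1.0 - torch.eye(x.shape[0]))


def graph_mlp_contrast_loss(
    x_dis: torch.Tensor, adj_label: torch.Tensor, tau: float = 1.0
) -> torch.Tensor:
    """Sketch of a Graph-MLP style neighbourhood contrastive loss."""
    # Assumption: tau scales the similarities multiplicatively.
    x_exp = torch.exp(tau * x_dis)
    # Denominator: all nodes j for each anchor node i.
    total = x_exp.sum(dim=1)
    # Numerator: only positive pairs, weighted by the adjacency entries.
    positive = (x_exp * adj_label).sum(dim=1)
    # The small epsilon guards against log(0) for nodes without positives.
    return -torch.log(positive / total + 1e-8).mean()


if __name__ == "__main__":
    x = torch.tensor([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]])
    x_dis = feature_similarity(x)
    # Pairing the two similar nodes (0 and 1) as positives...
    adj_good = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    # ...versus pairing dissimilar nodes.
    adj_bad = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    loss_good = graph_mlp_contrast_loss(x_dis, adj_good)
    loss_bad = graph_mlp_contrast_loss(x_dis, adj_bad)
    # Hypothetical weighting step, mirroring the loss_weight parameter (default 0.5).
    print(0.5 * loss_good.item(), 0.5 * loss_bad.item())
```

When positives are the most similar pairs the ratio inside the log grows and the loss falls, so minimising it pulls the embeddings of graph neighbours together without any message passing at inference time.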
Module contents
This module implements the loss functions for the topobench package.