topobench.loss.model.GraphMLPLoss module
Graph MLP loss function.
- class AbstractLoss
Bases: ABC
Abstract base class for loss functions.
- __init__()
- abstract forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
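A minimal sketch of the contract AbstractLoss defines: a subclass must implement forward(model_out, batch), and the abstract class itself cannot be instantiated. The stand-in below uses only the standard library; ConstantLoss and the "loss" dictionary key are hypothetical, chosen for illustration, and are not part of the documented API.

```python
from abc import ABC, abstractmethod


class AbstractLoss(ABC):
    """Simplified stand-in for the documented base class."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, model_out, batch):
        """Compute the loss from the model output and the batch."""


class ConstantLoss(AbstractLoss):
    """Hypothetical subclass that adds a fixed penalty under a "loss" key."""

    def __init__(self, value=1.0):
        super().__init__()
        self.value = value

    def forward(self, model_out, batch):
        # Copy so the caller's dictionary is left untouched; `batch` is unused here.
        out = dict(model_out)
        out["loss"] = out.get("loss", 0.0) + self.value
        return out


# AbstractLoss() would raise TypeError because forward is abstract.
print(ConstantLoss(2.0).forward({"loss": 1.0}, batch=None))  # {'loss': 3.0}
```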
- class GraphMLPLoss(r_adj_power=2, tau=1.0, loss_weight=0.5)
Bases: AbstractLoss
Graph MLP loss function.
- Parameters:
- r_adj_power : int, optional
Power to which the adjacency matrix is raised (default: 2).
- tau : float, optional
Temperature parameter (default: 1.0).
- loss_weight : float, optional
Weight applied to the loss (default: 0.5).
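Assuming GraphMLPLoss implements the neighbour-contrastive objective of Graph-MLP, the three constructor arguments would enter roughly as below, with $s_{ij}$ the entries of the matrix passed as x_dis, $A$ the adjacency matrix built from edge_index, $N$ the number of nodes, $r$ = r_adj_power, $\tau$ = tau, $\lambda$ = loss_weight, and $\varepsilon$ a small stabilising constant. Whether $\tau$ multiplies or divides the similarity, and whether $\lambda$ scales the averaged loss exactly this way, are assumptions here; the two temperature conventions coincide at the default $\tau = 1$.

$$
\mathcal{L} = \lambda \cdot \frac{1}{N} \sum_{i=1}^{N} -\log\left( \frac{\sum_{j=1}^{N} (A^{r})_{ij} \exp(\tau s_{ij})}{\sum_{k=1}^{N} \exp(\tau s_{ik})} + \varepsilon \right)
$$

A larger $r$ admits more distant nodes as positives, and a larger $\tau$ sharpens the softmax-like ratio inside the logarithm.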
- __init__(r_adj_power=2, tau=1.0, loss_weight=0.5)
- forward(model_out, batch)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output with the loss.
- get_power_adj(edge_index)
Compute the power of the adjacency matrix, using the exponent set by r_adj_power.
- Parameters:
- edge_index : torch.Tensor
Edge index tensor of shape [2, num_edges].
- Returns:
- torch.Tensor
Power of the adjacency matrix.
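The sketch below shows the computation in pure Python, assuming the method builds a dense 0/1 adjacency matrix from edge_index and multiplies it by itself r_adj_power - 1 times. Entry (i, j) of A^r counts the walks of length exactly r from node i to node j; if self-loops are added first, nonzero entries instead mark every node within r hops. The helper _matmul and the num_nodes argument are hypothetical, introduced only for illustration.

```python
def _matmul(a, b):
    """Multiply two dense matrices stored as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def get_power_adj(edge_index, r_adj_power=2, num_nodes=None):
    """Return A ** r_adj_power for the graph described by edge_index.

    edge_index is a pair (sources, targets) of equal-length sequences,
    mirroring the [2, num_edges] layout of a PyG edge_index tensor.
    num_nodes is hypothetical; by default it is inferred as the largest
    node index plus one, so trailing isolated nodes are not represented.
    """
    src, dst = edge_index
    if num_nodes is None:
        num_nodes = max(max(src), max(dst)) + 1
    adj = [[0] * num_nodes for _ in range(num_nodes)]
    for i, j in zip(src, dst):
        adj[i][j] = 1  # one directed entry per listed edge
    power = adj
    for _ in range(r_adj_power - 1):
        power = _matmul(power, adj)
    return power


# Undirected path 0 - 1 - 2, with both edge directions listed.
print(get_power_adj(([0, 1, 1, 2], [1, 0, 2, 1])))
# [[1, 0, 1], [0, 2, 0], [1, 0, 1]]
```

With r = 2 on this path, nodes 0 and 2 become positives for each other while direct neighbours (0, 1) do not, because no walk of length exactly 2 joins them.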
- graph_mlp_contrast_loss(x_dis, adj_label)
Graph MLP contrastive loss.
- Parameters:
- x_dis : torch.Tensor
Distance matrix.
- adj_label : torch.Tensor
Adjacency matrix.
- Returns:
- torch.Tensor
Contrastive loss.
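A dependency-free sketch of the contrastive term, assuming it follows the Graph-MLP neighbour-contrastive form: for each node i, the share of exponentiated similarity mass that falls on its positives (weighted by adj_label) is pushed toward 1. Although the parameter is documented as a distance matrix, the sketch treats x_dis as a similarity matrix (larger means closer), since that is what this form rewards; the tau scaling and the eps stabiliser are assumptions.

```python
import math


def graph_mlp_contrast_loss(x_dis, adj_label, tau=1.0, eps=1e-8):
    """Neighbour-contrastive loss over dense nested-list matrices.

    x_dis[i][j]: similarity score between nodes i and j (larger = closer).
    adj_label[i][j]: non-negative weight marking j as a positive for i.
    """
    total = 0.0
    for sims, weights in zip(x_dis, adj_label):
        # Assumed scaling: exp(tau * s); an implementation may divide by tau instead.
        exp_sims = [math.exp(tau * s) for s in sims]
        positives = sum(w * e for w, e in zip(weights, exp_sims))
        # eps keeps the log finite for rows that have no positive pairs.
        total += -math.log(positives / sum(exp_sims) + eps)
    return total / len(x_dis)


adj = [[0, 1], [1, 0]]
print(graph_mlp_contrast_loss([[0.0, 0.0], [0.0, 0.0]], adj))  # about log(2) = 0.693
print(graph_mlp_contrast_loss([[0.0, 5.0], [5.0, 0.0]], adj))  # smaller: positives dominate
```

Raising the similarity of positive pairs drives each row's ratio toward 1 and the loss toward 0, which is the behaviour the contrastive term is meant to encourage.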