topobench.loss.model.GraphMLPLoss module

Graph MLP loss function.

class AbstractLoss

Bases: ABC

Abstract base class for loss functions.

__init__()
abstract forward(model_out, batch)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
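A subclass only has to implement forward(model_out, batch). The sketch below is a minimal illustration, assuming a local stand-in for AbstractLoss that mirrors the documented signature (it is not imported from topobench). The __call__ dispatch, the MSELoss subclass, the "logits"/"loss" keys and the batch.y attribute are all hypothetical choices for this example; a SimpleNamespace stands in for the torch_geometric.data.Data batch.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace

import torch


class AbstractLoss(ABC):
    """Stand-in mirroring the documented interface (not the topobench class)."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, model_out: dict, batch) -> dict:
        """Compute a loss from the model output and the batch."""

    def __call__(self, model_out: dict, batch) -> dict:
        # Hypothetical convenience: calling the object dispatches to forward().
        return self.forward(model_out, batch)


class MSELoss(AbstractLoss):
    """Hypothetical subclass: mean squared error between predictions and targets."""

    def forward(self, model_out: dict, batch) -> dict:
        # "logits" and batch.y are assumed names used only in this sketch.
        model_out["loss"] = torch.mean((model_out["logits"] - batch.y) ** 2)
        return model_out


batch = SimpleNamespace(y=torch.tensor([1.0, 2.0, 3.0]))
out = MSELoss()({"logits": torch.tensor([1.0, 2.0, 5.0])}, batch)
print(out["loss"])  # (0 + 0 + 4) / 3
```

Because forward is declared with @abstractmethod, AbstractLoss itself cannot be instantiated; only concrete subclasses can.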

class GraphMLPLoss(r_adj_power=2, tau=1.0, loss_weight=0.5)

Bases: AbstractLoss

Graph MLP loss function.

Parameters:
r_adj_power : int, optional

Power of the adjacency matrix (default: 2).

tau : float, optional

Temperature parameter (default: 1.0).

loss_weight : float, optional

Loss weight (default: 0.5).
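For orientation, the Graph-MLP paper (Hu et al., 2021) defines a neighbouring contrastive loss over a batch of B nodes roughly as below, where z_i is the embedding of node i, sim is cosine similarity, and γ_ij is entry (i, j) of the r-th power of the normalized adjacency matrix (non-zero only when j is within r hops of i). Reading r_adj_power as r, tau as τ and loss_weight as α is an assumption; this page does not state how TopoBench maps its parameters onto the formula.

```latex
\ell_i = -\log
  \frac{\sum_{j=1}^{B} \mathbb{1}_{[j \neq i]}\,\gamma_{ij}\,
        \exp\!\left(\operatorname{sim}(z_i, z_j)/\tau\right)}
       {\sum_{k=1}^{B} \mathbb{1}_{[k \neq i]}\,
        \exp\!\left(\operatorname{sim}(z_i, z_k)/\tau\right)},
\qquad
\mathcal{L}_{NC} = \frac{1}{B}\sum_{i=1}^{B} \ell_i,
\qquad
\mathcal{L} = \mathcal{L}_{CE} + \alpha\,\mathcal{L}_{NC}
```

Nodes that are close in the graph (large γ_ij) are pulled together in embedding space, while all other nodes in the batch act as negatives through the denominator.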

__init__(r_adj_power=2, tau=1.0, loss_weight=0.5)
forward(model_out, batch)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
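The following is a minimal end-to-end sketch of what a forward pass of this kind could look like, assuming the Graph-MLP formulation: cosine similarities between node embeddings form the x_dis matrix, a dense power of the adjacency gives the positive-pair weights, and the weighted contrastive term is added to the output dictionary. The function graph_mlp_forward, the "x_0" embedding key, the "loss" key, the explicit num_nodes argument, dividing by tau, and the use of an unnormalized adjacency are all hypothetical choices, not TopoBench's confirmed behaviour.

```python
import torch
import torch.nn.functional as F


def feature_similarity(x: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every pair of rows of x, with the diagonal zeroed."""
    x_unit = F.normalize(x, p=2, dim=1)
    sim = x_unit @ x_unit.T
    return sim * (1.0 - torch.eye(x.size(0), device=x.device))


def graph_mlp_forward(model_out, edge_index, num_nodes,
                      r_adj_power=2, tau=1.0, loss_weight=0.5, eps=1e-8):
    """Hypothetical stand-in for GraphMLPLoss.forward; keys "x_0" and "loss" are assumed."""
    x_dis = feature_similarity(model_out["x_0"])

    # Dense r-th power of the unnormalized adjacency: soft weights for positive pairs.
    adj = torch.zeros(num_nodes, num_nodes, device=x_dis.device)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj_label = torch.linalg.matrix_power(adj, r_adj_power)

    # Neighbouring contrastive term, excluding self-pairs from both sums.
    off_diag = 1.0 - torch.eye(num_nodes, device=x_dis.device)
    exp_sim = torch.exp(x_dis / tau) * off_diag
    ratio = (exp_sim * adj_label).sum(dim=1) / exp_sim.sum(dim=1)
    contrast = -torch.log(ratio + eps).mean()

    # Add the weighted contrastive term to any loss already present in model_out.
    model_out["loss"] = model_out.get("loss", 0.0) + loss_weight * contrast
    return model_out


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # path 0-1-2-3
out = graph_mlp_forward({"x_0": torch.randn(4, 16), "loss": torch.tensor(1.0)},
                        edge_index, num_nodes=4)
print(out["loss"])
```

With loss_weight=0.0 the existing loss passes through unchanged, which is a quick way to check the wiring.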

get_power_adj(edge_index)

Compute the r_adj_power-th power of the adjacency matrix defined by edge_index.

Parameters:
edge_index : torch.Tensor

Edge index tensor.

Returns:
torch.Tensor

Power of the adjacency matrix.
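As an illustration of what "power of the adjacency matrix" means, here is a minimal sketch that builds a dense 0/1 adjacency from a [2, num_edges] edge_index and raises it to the r-th power with torch.linalg.matrix_power. The free-function form and the r_adj_power/num_nodes arguments are hypothetical (the documented method takes only edge_index), and the sketch skips the adjacency normalization used in the Graph-MLP paper.

```python
from typing import Optional

import torch


def get_power_adj(edge_index: torch.Tensor, r_adj_power: int = 2,
                  num_nodes: Optional[int] = None) -> torch.Tensor:
    """Dense r-th power of the unnormalized adjacency matrix built from edge_index."""
    if num_nodes is None:
        # Infer the node count from the largest index (misses trailing isolated nodes).
        num_nodes = int(edge_index.max()) + 1
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    # Entry (i, j) of A^r counts the walks of length r from node i to node j.
    return torch.linalg.matrix_power(adj, r_adj_power)


# Undirected path 0 - 1 - 2, stored with both edge directions.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
print(get_power_adj(edge_index, r_adj_power=2))
# expected A^2: [[1, 0, 1], [0, 2, 0], [1, 0, 1]]
```

For r = 2, the off-diagonal entries mark two-hop neighbours and the diagonal holds each node's degree.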

graph_mlp_contrast_loss(x_dis, adj_label)

Graph MLP contrastive loss.

Parameters:
x_dis : torch.Tensor

Distance matrix.

adj_label : torch.Tensor

Adjacency matrix.

Returns:
torch.Tensor

Contrastive loss.
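A minimal sketch of a neighbouring contrastive loss in the Graph-MLP style, assuming x_dis holds pairwise similarities (for example cosine similarities) and adj_label holds non-negative neighbour weights such as a power of the adjacency. The tau and eps arguments are hypothetical additions to the documented signature, dividing by tau follows the paper's notation, and TopoBench may apply the temperature differently.

```python
import torch


def graph_mlp_contrast_loss(x_dis: torch.Tensor, adj_label: torch.Tensor,
                            tau: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    """Sketch of -mean_i log(sum_j w_ij e^{s_ij/tau} / sum_k e^{s_ik/tau}), j, k != i."""
    n = x_dis.size(0)
    # Exclude self-pairs (the indicator 1[j != i] in the Graph-MLP formulation).
    off_diag = 1.0 - torch.eye(n, dtype=x_dis.dtype, device=x_dis.device)
    exp_sim = torch.exp(x_dis / tau) * off_diag
    pos = (exp_sim * adj_label).sum(dim=1)  # neighbour-weighted numerator
    total = exp_sim.sum(dim=1)              # every other node in the batch
    # eps keeps the log finite for nodes with no neighbour inside the batch.
    return -torch.log(pos / total + eps).mean()


# Three nodes with equal similarities; only nodes 0 and 1 are neighbours.
x_dis = torch.zeros(3, 3)
adj_label = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
print(graph_mlp_contrast_loss(x_dis, adj_label))
```

In this example nodes 0 and 1 each get -log(0.5), while node 2 has no neighbour and is penalized with -log(eps); raising the similarity between nodes 0 and 1 lowers the loss.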