topobench.nn.wrappers.graph package

Wrappers for graph neural networks with automated exports.

class GNNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the GNN models.

This wrapper defines the forward pass of the model. The GNN models return embeddings of the rank-0 cells, i.e., the nodes of the graph.

forward(batch)

Forward pass for the GNN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
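A minimal pure-Python sketch of the wrapper pattern described above may help. It assumes that the batch exposes node (rank-0 cell) features as `x_0`, connectivity as `edge_index`, labels as `y`, and the node-to-graph assignment as `batch_0`, and that the output dictionary uses the keys `x_0`, `labels` and `batch_0`. These attribute and key names, as well as `toy_backbone` and `GNNWrapperSketch`, are hypothetical choices made for illustration; they are not taken from topobench's implementation. A real backbone would be a `torch.nn.Module` operating on tensors rather than nested lists.

```python
from types import SimpleNamespace


def toy_backbone(x, edge_index):
    """Hypothetical stand-in for a GNN: one message-passing step in which each
    node adds the features of its in-neighbours to its own features."""
    out = [row[:] for row in x]
    sources, targets = edge_index
    for s, t in zip(sources, targets):
        out[t] = [a + b for a, b in zip(out[t], x[s])]
    return out


class GNNWrapperSketch:
    """Hypothetical illustration of the wrapper pattern, not topobench's code."""

    def __init__(self, backbone):
        self.backbone = backbone

    def forward(self, batch):
        # Run the backbone on node features plus graph connectivity (assumed attribute names).
        x_0 = self.backbone(batch.x_0, batch.edge_index)
        # Package node embeddings with labels and batch assignment (assumed key names).
        return {"x_0": x_0, "labels": batch.y, "batch_0": batch.batch_0}


batch = SimpleNamespace(
    x_0=[[1.0], [2.0], [3.0]],
    edge_index=([0, 1], [1, 2]),  # directed edges 0->1 and 1->2
    y=[0, 1, 0],
    batch_0=[0, 0, 0],  # all three nodes belong to graph 0
)
out = GNNWrapperSketch(toy_backbone).forward(batch)
print(out["x_0"])  # [[1.0], [3.0], [5.0]]
```

Returning a dictionary rather than a bare tensor lets downstream readout and loss components pick out the fields they need without depending on positional outputs.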

class GraphMLPWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the GraphMLP models.

This wrapper defines the forward pass of the model. The wrapped models return embeddings of the rank-0 cells, i.e., the nodes of the graph.

forward(batch)

Forward pass for the GraphMLP wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
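The following pure-Python sketch illustrates how a GraphMLP-style wrapper could differ from a message-passing one. It assumes, following the Graph-MLP idea of training an MLP with a neighbour contrastive loss, that the backbone consumes only node features (no `edge_index`) and returns both the node embeddings and a pairwise similarity matrix for use in such a loss. The output key `x_dis`, the other attribute and key names, `toy_mlp_backbone`, and `GraphMLPWrapperSketch` are all hypothetical and are not confirmed by this reference page.

```python
from types import SimpleNamespace


def toy_mlp_backbone(x):
    """Hypothetical stand-in for a GraphMLP backbone: a per-node linear map
    (connectivity is not used) plus pairwise dot-product similarities."""
    weight = 2.0  # hypothetical single scalar weight
    h = [[weight * v for v in row] for row in x]
    sim = [[sum(a * b for a, b in zip(hi, hj)) for hj in h] for hi in h]
    return h, sim


class GraphMLPWrapperSketch:
    """Hypothetical illustration, not topobench's code."""

    def __init__(self, backbone):
        self.backbone = backbone

    def forward(self, batch):
        # Only node features go through the backbone; edge_index is not passed in.
        x_0, x_dis = self.backbone(batch.x_0)
        # The similarity matrix is returned alongside the embeddings (assumed key "x_dis")
        # so that a contrastive loss can compare it with the graph's adjacency.
        return {"x_0": x_0, "x_dis": x_dis, "labels": batch.y, "batch_0": batch.batch_0}


batch = SimpleNamespace(x_0=[[1.0], [2.0]], y=[0, 1], batch_0=[0, 0])
out = GraphMLPWrapperSketch(toy_mlp_backbone).forward(batch)
print(out["x_0"])    # [[2.0], [4.0]]
print(out["x_dis"])  # [[4.0, 8.0], [8.0, 16.0]]
```

Keeping the graph structure out of the forward pass means inference needs only node features, while the structure still shapes training through the loss.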

Submodules