topobench.nn.wrappers.graph package
Submodules
topobench.nn.wrappers.graph.gnn_wrapper module
Wrapper for the GNN models.
- class topobench.nn.wrappers.graph.gnn_wrapper.GNNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the GNN models.
This wrapper defines the model's forward pass. The wrapped GNN backbone returns embeddings of the rank-0 cells, i.e. the graph's nodes.
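To make the wrapper contract concrete, here is a dependency-free sketch of the pattern described above: the wrapper holds a backbone, calls it on the rank-0 (node) features and the connectivity, and packages the result in a dictionary. It is not TopoBench's implementation. `AbstractWrapperSketch`, `GNNWrapperSketch` and `mean_neighbor_backbone` are hypothetical, and the batch attribute names (`x_0`, `edge_index`, `y`, `batch_0`) and output keys are assumptions; a real backbone would be a PyTorch Geometric model operating on tensors.

```python
from types import SimpleNamespace


class AbstractWrapperSketch:
    """Hypothetical stand-in for an abstract wrapper: it stores a backbone."""

    def __init__(self, backbone, **kwargs):
        self.backbone = backbone
        self.kwargs = kwargs  # extra options are kept but unused in this sketch

    def __call__(self, batch):
        return self.forward(batch)


class GNNWrapperSketch(AbstractWrapperSketch):
    """Sketch of a GNN wrapper: run the backbone on rank-0 (node) features."""

    def forward(self, batch):
        # Assumed attribute names: x_0 (node features), edge_index, y, batch_0.
        x_0 = self.backbone(batch.x_0, batch.edge_index)
        return {"labels": batch.y, "batch_0": batch.batch_0, "x_0": x_0}


def mean_neighbor_backbone(x, edge_index):
    """Toy 'GNN': each node averages its own feature with its in-neighbours'."""
    sources, targets = edge_index
    sums = list(x)  # copy, so messages are read from the original features
    counts = [1] * len(x)
    for s, t in zip(sources, targets):
        sums[t] += x[s]
        counts[t] += 1
    return [total / n for total, n in zip(sums, counts)]


batch = SimpleNamespace(
    x_0=[1.0, 2.0, 3.0],
    edge_index=([0, 1], [1, 2]),  # directed edges 0->1 and 1->2
    y=[0],
    batch_0=[0, 0, 0],
)
out = GNNWrapperSketch(mean_neighbor_backbone)(batch)
print(out["x_0"])  # [1.0, 1.5, 2.5]
```

Keeping the backbone call inside `forward` and returning a dictionary lets different backbones share the same downstream readout and loss code.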
topobench.nn.wrappers.graph.graph_mlp_wrapper module
Wrapper for the graph MLP models.
- class topobench.nn.wrappers.graph.graph_mlp_wrapper.GraphMLPWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the graph MLP models.
This wrapper defines the model's forward pass. The wrapped backbone returns embeddings of the rank-0 cells, i.e. the graph's nodes.
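This page does not spell out how `GraphMLPWrapper` differs from `GNNWrapper`. The sketch below assumes, as an illustration only, that its backbone is an MLP that consumes node features without the edge list. `GraphMLPWrapperSketch` and `tiny_mlp` are hypothetical names, and the batch attributes and output keys are assumptions rather than the library's API.

```python
from types import SimpleNamespace


class GraphMLPWrapperSketch:
    """Hypothetical sketch: the backbone sees node features only, no edges."""

    def __init__(self, backbone, **kwargs):
        self.backbone = backbone

    def forward(self, batch):
        # Assumed attribute names; edge_index is deliberately not passed on.
        x_0 = self.backbone(batch.x_0)
        return {"labels": batch.y, "batch_0": batch.batch_0, "x_0": x_0}


def tiny_mlp(x, weight=2.0, bias=-1.0):
    """Toy per-node MLP layer: ReLU(weight * x + bias), applied independently."""
    return [max(0.0, weight * v + bias) for v in x]


batch = SimpleNamespace(
    x_0=[0.0, 1.0, 2.0],
    edge_index=([0], [1]),  # present in the batch but ignored by the backbone
    y=[1],
    batch_0=[0, 0, 0],
)
out = GraphMLPWrapperSketch(tiny_mlp).forward(batch)
print(out["x_0"])  # [0.0, 1.0, 3.0]
```

Because the backbone in this sketch never receives `edge_index`, changing the edges leaves `out["x_0"]` unchanged.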
Module contents
Wrappers for graph neural networks. The package re-exports GNNWrapper and GraphMLPWrapper, so both can be imported directly from topobench.nn.wrappers.graph.
- class topobench.nn.wrappers.graph.GNNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the GNN models.
This wrapper defines the model's forward pass. The wrapped GNN backbone returns embeddings of the rank-0 cells, i.e. the graph's nodes.
- forward(batch)
Forward pass for the GNN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
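The Returns entry says only that `forward` produces a dictionary of model outputs. As a sketch of how a caller might consume it for graph-level tasks, the function below averages the rank-0 embeddings of each graph in a batch. The keys `"x_0"` and `"batch_0"` are assumed names for the node embeddings and each node's graph index, and `mean_readout` is hypothetical; with tensors, PyTorch Geometric's `torch_geometric.nn.global_mean_pool(x, batch)` performs the same reduction.

```python
from collections import defaultdict


def mean_readout(model_out):
    """Average rank-0 embeddings per graph, using each node's graph index."""
    # Key names "x_0" and "batch_0" are assumptions for this sketch.
    sums = defaultdict(float)
    counts = defaultdict(int)
    for value, graph_id in zip(model_out["x_0"], model_out["batch_0"]):
        sums[graph_id] += value
        counts[graph_id] += 1
    return [sums[g] / counts[g] for g in sorted(sums)]


# Two graphs in one batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
model_out = {
    "labels": [0, 1],
    "batch_0": [0, 0, 1, 1, 1],
    "x_0": [1.0, 3.0, 2.0, 4.0, 6.0],
}
print(mean_readout(model_out))  # [2.0, 4.0]
```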
- class topobench.nn.wrappers.graph.GraphMLPWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the graph MLP models.
This wrapper defines the model's forward pass. The wrapped backbone returns embeddings of the rank-0 cells, i.e. the graph's nodes.
- forward(batch)
Forward pass for the graph MLP wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.