topobench.nn.wrappers package
Subpackages
- topobench.nn.wrappers.cell package
- topobench.nn.wrappers.combinatorial package
- topobench.nn.wrappers.graph package
- topobench.nn.wrappers.hypergraph package
- topobench.nn.wrappers.simplicial package
Submodules
topobench.nn.wrappers.base module
Abstract class that provides an interface to handle the network output.
- class topobench.nn.wrappers.base.AbstractWrapper(backbone, **kwargs)
Bases: ABC, torch.nn.Module
Abstract class that provides an interface to handle the network output.
- Parameters:
- backbone : torch.nn.Module
Backbone model.
- **kwargs : dict
Additional arguments for the class. It should contain the following keys:
- out_channels (int): Number of output channels.
- num_cell_dimensions (int): Number of cell dimensions.
- abstract forward(batch)
Forward pass for the model.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
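A minimal sketch of how a concrete wrapper fills in the abstract `forward`. To stay self-contained it defines a simplified stand-in base class, `SketchWrapper`, instead of importing TopoBench; it is not the TopoBench source. The subclass `NodeOnlyWrapper`, the batch attributes (`x_0`, `y`, `batch_0`), and the output keys are hypothetical choices for illustration, and a `SimpleNamespace` stands in for `torch_geometric.data.Data`.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace

import torch
from torch import nn


class SketchWrapper(ABC, nn.Module):
    """Simplified stand-in for AbstractWrapper (hypothetical, not the TopoBench source)."""

    def __init__(self, backbone, **kwargs):
        super().__init__()
        self.backbone = backbone
        # The two keys the documentation lists for **kwargs.
        self.out_channels = kwargs["out_channels"]
        self.num_cell_dimensions = kwargs["num_cell_dimensions"]

    @abstractmethod
    def forward(self, batch):
        """Map a batch to a dict of per-rank embeddings."""


class NodeOnlyWrapper(SketchWrapper):
    """Hypothetical wrapper whose backbone only embeds rank-0 cells."""

    def forward(self, batch):
        x_0 = self.backbone(batch.x_0)
        # Assumed output keys: per-rank embeddings plus labels and batch vector.
        return {"x_0": x_0, "labels": batch.y, "batch_0": batch.batch_0}


wrapper = NodeOnlyWrapper(nn.Linear(4, 8), out_channels=8, num_cell_dimensions=1)
batch = SimpleNamespace(
    x_0=torch.randn(5, 4),
    y=torch.zeros(5),
    batch_0=torch.zeros(5, dtype=torch.long),
)
out = wrapper(batch)  # nn.Module.__call__ dispatches to forward
```

Because `forward` is abstract, the base class itself cannot be instantiated; only subclasses that implement it can.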
- residual_connection(model_out, batch)
Residual connection for the model.
For the cells of every rank, this method adds the model output embeddings to the corresponding input embeddings and applies layer normalization to the result.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
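The residual step can be sketched as a standalone function. This is an illustration under stated assumptions, not the TopoBench method: the hypothetical helper takes a plain dict of input embeddings keyed `x_{i}` instead of a `torch_geometric` batch, builds one `nn.LayerNorm(out_channels)` per rank from the two documented `**kwargs` keys, and (as an assumption) skips ranks that are missing or whose shapes differ.

```python
import torch
from torch import nn


def residual_connection(model_out, batch, layer_norms):
    """Add input embeddings to model outputs rank by rank, then layer-normalize.

    Hypothetical helper: `batch` is a dict of input embeddings keyed "x_{i}",
    `layer_norms` holds one nn.LayerNorm per rank.
    """
    for i, ln in enumerate(layer_norms):
        key = f"x_{i}"
        x_in = batch.get(key)
        # Assumption: skip ranks that are missing or whose shapes differ.
        if key in model_out and x_in is not None and model_out[key].shape == x_in.shape:
            model_out[key] = ln(model_out[key] + x_in)
    return model_out


out_channels, num_cell_dimensions = 8, 2
layer_norms = nn.ModuleList(nn.LayerNorm(out_channels) for _ in range(num_cell_dimensions))
batch = {"x_0": torch.randn(5, 8), "x_1": torch.randn(7, 8)}
model_out = {"x_0": torch.randn(5, 8), "x_1": torch.randn(7, 8), "labels": torch.zeros(5)}
out = residual_connection(model_out, batch, layer_norms)
```

Entries that are not per-rank embeddings, such as `labels`, pass through unchanged.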
Module contents
Wrappers implemented for TopoBench with automated exports.
- class topobench.nn.wrappers.AbstractWrapper(backbone, **kwargs)
Bases: ABC, torch.nn.Module
Abstract class that provides an interface to handle the network output.
- Parameters:
- backbone : torch.nn.Module
Backbone model.
- **kwargs : dict
Additional arguments for the class. It should contain the following keys:
- out_channels (int): Number of output channels.
- num_cell_dimensions (int): Number of cell dimensions.
- abstract forward(batch)
Forward pass for the model.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- residual_connection(model_out, batch)
Residual connection for the model.
For the cells of every rank, this method adds the model output embeddings to the corresponding input embeddings and applies layer normalization to the result.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.CANWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the CAN model.
This wrapper defines the forward pass of the model. The CAN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.
- forward(batch)
Forward pass for the CAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
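The CAN, CCCN, and SAN wrappers describe the same rank-1 to rank-0 step. A minimal sketch of that step, assuming the batch provides an unsigned node-to-edge incidence matrix (called `incidence_1` here; both the name and the unsigned convention are assumptions), so that a sparse matrix product sums the embeddings of the edges incident to each node:

```python
import torch


def edges_to_nodes(incidence_1, x_1):
    """Sum each edge embedding into the nodes it is incident to.

    incidence_1: sparse (num_nodes x num_edges) unsigned incidence matrix (assumed).
    x_1: dense (num_edges x channels) edge embeddings.
    """
    return torch.sparse.mm(incidence_1, x_1)


# Path graph 0-1-2 with edges e0 = (0, 1) and e1 = (1, 2).
indices = torch.tensor([[0, 1, 1, 2],   # node (row) index
                        [0, 0, 1, 1]])  # edge (column) index
incidence_1 = torch.sparse_coo_tensor(indices, torch.ones(4), size=(3, 2)).coalesce()
x_1 = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
x_0 = edges_to_nodes(incidence_1, x_1)
print(x_0)  # rows: e0, e0 + e1, e1
```

With a signed incidence matrix the same product would instead give oriented sums, which is why the sign convention matters here.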
- class topobench.nn.wrappers.CCCNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the CCCN model.
This wrapper defines the forward pass of the model. The CCCN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.
- forward(batch)
Forward pass for the CCCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.CCXNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the CCXN model.
This wrapper defines the forward pass of the model. The CCXN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the CCXN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
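For models such as CCXN and CWN that return embeddings for ranks 0, 1, and 2, the wrapper's job is largely to unpack the backbone outputs into the per-rank output dictionary. A minimal sketch: `ToyThreeRankBackbone` is a hypothetical stand-in (one linear map per rank) rather than the real CCXN or CWN, which also consume adjacency and incidence matrices, and the dictionary keys are assumed.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyThreeRankBackbone(nn.Module):
    """Hypothetical backbone: one linear map per rank (stands in for CCXN/CWN)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lins = nn.ModuleList(nn.Linear(in_channels, out_channels) for _ in range(3))

    def forward(self, x_0, x_1, x_2):
        return tuple(lin(x) for lin, x in zip(self.lins, (x_0, x_1, x_2)))


def three_rank_forward(backbone, batch):
    """Unpack the backbone's per-rank outputs into an output dict (assumed keys)."""
    x_0, x_1, x_2 = backbone(batch.x_0, batch.x_1, batch.x_2)
    return {"x_0": x_0, "x_1": x_1, "x_2": x_2, "labels": batch.y, "batch_0": batch.batch_0}


batch = SimpleNamespace(
    x_0=torch.randn(4, 3), x_1=torch.randn(5, 3), x_2=torch.randn(2, 3),
    y=torch.tensor([1]), batch_0=torch.zeros(4, dtype=torch.long),
)
out = three_rank_forward(ToyThreeRankBackbone(3, 6), batch)
print({k: tuple(v.shape) for k, v in out.items()})
```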
- class topobench.nn.wrappers.CWNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the CWN model.
This wrapper defines the forward pass of the model. The CWN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the CWN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.GNNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the GNN models.
This wrapper defines the forward pass of the model. The GNN models return the embeddings of the cells of rank 0.
- forward(batch)
Forward pass for the GNN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
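A GNN backbone follows the usual PyTorch Geometric convention of taking node features and an `edge_index` tensor of shape `[2, num_edges]`, so the wrapper only needs to produce rank-0 embeddings. A minimal sketch, with a hypothetical one-layer sum-aggregation GNN (`ToySumGNN`) in place of a real backbone and assumed output keys:

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToySumGNN(nn.Module):
    """Hypothetical one-layer GNN: h_v = W (x_v + sum of x_u over edges u -> v)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index):
        src, dst = edge_index
        # Scatter-add each source node's features into its destination node.
        agg = x.clone().index_add_(0, dst, x[src])
        return self.lin(agg)


def gnn_forward(backbone, batch):
    """Sketch of a GNN wrapper forward: only rank-0 (node) embeddings are produced."""
    x_0 = backbone(batch.x_0, batch.edge_index)
    return {"x_0": x_0, "labels": batch.y, "batch_0": batch.batch_0}


gnn = ToySumGNN(2, 2)
with torch.no_grad():
    gnn.lin.weight.copy_(torch.eye(2))  # identity weights make the output easy to check

batch = SimpleNamespace(
    x_0=torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),
    edge_index=torch.tensor([[0, 1], [1, 2]]),  # directed edges 0 -> 1 and 1 -> 2
    y=torch.tensor([0, 1, 0]),
    batch_0=torch.zeros(3, dtype=torch.long),
)
out = gnn_forward(gnn, batch)
```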
- class topobench.nn.wrappers.GraphMLPWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the GraphMLP model.
This wrapper defines the forward pass of the model. The GraphMLP model returns the embeddings of the cells of rank 0.
- forward(batch)
Forward pass for the GraphMLP wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.HypergraphWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the hypergraph models.
This wrapper defines the forward pass of the model. The hypergraph models return the embeddings of the cells of rank 0 (the nodes) and rank 1 (the hyperedges).
- forward(batch)
Forward pass for the hypergraph wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
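A minimal sketch of the two outputs a hypergraph wrapper collects. The backbone here is a hypothetical toy, not a TopoBench model: it computes each hyperedge embedding as the mean of its member nodes, then each node embedding as the sum of the hyperedges that contain it, using a dense node-by-hyperedge incidence matrix. The output key names are also assumptions.

```python
import torch


def toy_hypergraph_backbone(x_0, incidence):
    """Hypothetical backbone returning (node, hyperedge) embeddings.

    incidence: dense (num_nodes x num_hyperedges) 0/1 float matrix.
    """
    # Rank 1: each hyperedge takes the mean of its member nodes.
    sizes = incidence.sum(dim=0).clamp(min=1).unsqueeze(1)
    x_1 = incidence.T @ x_0 / sizes
    # Rank 0: each node sums the embeddings of the hyperedges that contain it.
    x_0_out = incidence @ x_1
    return x_0_out, x_1


def hypergraph_forward(x_0, incidence):
    x_0_out, x_1 = toy_hypergraph_backbone(x_0, incidence)
    # Assumed key names for the two ranks.
    return {"x_0": x_0_out, "x_1": x_1}


# Hyperedges h0 = {0, 1, 2} and h1 = {1, 2}.
incidence = torch.tensor([[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]])
x_0 = torch.tensor([[6.0], [0.0], [3.0]])
out = hypergraph_forward(x_0, incidence)
```

Here `h0` averages to 3.0 and `h1` to 1.5, so node 0 receives 3.0 while nodes 1 and 2, which belong to both hyperedges, receive 4.5.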
- class topobench.nn.wrappers.SANWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the SAN model.
This wrapper defines the forward pass of the model. The SAN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.
- forward(batch)
Forward pass for the SAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.SCCNNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the SCCNN model.
This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the SCCNN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.SCCNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the SCCN model.
This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.
- forward(batch)
Forward pass for the SCCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
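Because the SCCN model handles cells of any rank, its wrapper cannot hard-code `x_0`, `x_1`, `x_2`. A minimal sketch of rank-agnostic collection, assuming (hypothetically) that the backbone takes and returns dicts keyed `rank_{i}`; `ToyPerRankBackbone` is an illustrative stand-in with one linear map per rank, not the SCCN model.

```python
import torch
from torch import nn


class ToyPerRankBackbone(nn.Module):
    """Hypothetical backbone: an independent linear map for each rank."""

    def __init__(self, channels, max_rank):
        super().__init__()
        self.lins = nn.ModuleDict(
            {f"rank_{i}": nn.Linear(channels, channels) for i in range(max_rank + 1)}
        )

    def forward(self, features):
        return {key: lin(features[key]) for key, lin in self.lins.items()}


def any_rank_forward(backbone, x_by_rank):
    """Feed every rank to the backbone and rename its outputs to x_0, x_1, ..."""
    features = {f"rank_{i}": x for i, x in enumerate(x_by_rank)}
    out = backbone(features)
    return {f"x_{i}": out[f"rank_{i}"] for i in range(len(x_by_rank))}


max_rank, channels = 3, 4
x_by_rank = [torch.randn(n, channels) for n in (5, 9, 6, 2)]
out = any_rank_forward(ToyPerRankBackbone(channels, max_rank), x_by_rank)
```

The number of output keys follows the number of ranks in the input, so the same code serves complexes of any dimension.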
- class topobench.nn.wrappers.SCNWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the SCN model.
This wrapper defines the forward pass of the model. The SCN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the SCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- normalize_matrix(matrix)
Normalize the input matrix.
The normalization uses a diagonal matrix whose entries are the inverse square roots of the row sums of the absolute values of the input matrix.
- Parameters:
- matrix : torch.sparse.FloatTensor
Input matrix to be normalized.
- Returns:
- torch.sparse.FloatTensor
Normalized matrix.
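The docstring names the scaling matrix but not which side it multiplies. A minimal sketch, assuming the symmetric form D^(-1/2) M D^(-1/2) with D_ii = sum_j |M_ij|, and assuming that rows whose absolute sum is zero get a zero scale instead of a division by zero; both choices are assumptions, not confirmed by the documentation. For simplicity it densifies the input and returns a sparse COO result.

```python
import torch


def normalize_matrix(matrix):
    """Return D^(-1/2) M D^(-1/2) with D_ii = sum_j |M_ij| (assumed symmetric form)."""
    dense = matrix.to_dense()
    row_sums = dense.abs().sum(dim=1)
    # Inverse square root of each row sum; zero rows get a zero scale (assumption).
    inv_sqrt = torch.where(
        row_sums > 0, row_sums.clamp(min=1e-12).rsqrt(), torch.zeros_like(row_sums)
    )
    # Entry (i, j) is scaled by inv_sqrt[i] * inv_sqrt[j].
    normalized = inv_sqrt.unsqueeze(1) * dense * inv_sqrt.unsqueeze(0)
    return normalized.to_sparse()


m = torch.tensor([[2.0, -1.0], [-1.0, 1.0]]).to_sparse()
print(normalize_matrix(m).to_dense())
```

For this input the absolute row sums are 3 and 2, so the diagonal becomes 2/3 and 1/2 and each off-diagonal entry becomes -1/sqrt(6).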
- class topobench.nn.wrappers.TuneWrapper(backbone, **kwargs)#
Bases: AbstractWrapper
Wrapper for the TopoTune model.
This wrapper defines the forward pass of the TopoTune model.
- forward(batch)
Forward pass for the Tune wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.