topobench.nn.wrappers.simplicial.san_wrapper module
Wrapper for the SAN model.
- class AbstractWrapper(backbone, **kwargs)
Abstract class that provides an interface to handle the network output.
- Parameters:
- backbone : torch.nn.Module
Backbone model.
- **kwargs : dict
Additional arguments for the class. It should contain the following keys:
  - out_channels (int): Number of output channels.
  - num_cell_dimensions (int): Number of cell dimensions.
- __init__(backbone, **kwargs)
- abstract forward(batch)
Forward pass for the model.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
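The interface described above can be approximated in a few lines of plain Python. This is a minimal sketch, not the TopoBench code: the backbone is any callable rather than a `torch.nn.Module`, the batch is a plain `dict` rather than a `torch_geometric.data.Data` object, failing fast on missing keys is an assumption, and the names `AbstractWrapperSketch`, `IdentityWrapper`, and the `x_0` key are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable


class AbstractWrapperSketch(ABC):
    """Framework-free sketch of the wrapper interface (hypothetical, not the TopoBench class)."""

    # Keys that the documentation says **kwargs should contain.
    REQUIRED_KEYS = ("out_channels", "num_cell_dimensions")

    def __init__(self, backbone: Callable[..., Any], **kwargs: Any) -> None:
        missing = [k for k in self.REQUIRED_KEYS if k not in kwargs]
        if missing:
            # Assumption: reject missing keys early; the real class may behave differently.
            raise KeyError(f"missing required kwargs: {missing}")
        self.backbone = backbone
        self.out_channels = int(kwargs["out_channels"])
        # Cell ranks 0 .. num_cell_dimensions - 1 handled by the wrapper.
        self.dimensions = range(int(kwargs["num_cell_dimensions"]))

    @abstractmethod
    def forward(self, batch: dict) -> dict:
        """Subclasses map a batch to a dictionary of model outputs."""


class IdentityWrapper(AbstractWrapperSketch):
    """Hypothetical subclass: runs the rank-0 features through the backbone."""

    def forward(self, batch: dict) -> dict:
        return {"x_0": self.backbone(batch["x_0"])}


wrapper = IdentityWrapper(lambda x: x, out_channels=2, num_cell_dimensions=3)
print(list(wrapper.dimensions))  # [0, 1, 2]
print(wrapper.forward({"x_0": [[1.0, 2.0]]}))  # {'x_0': [[1.0, 2.0]]}
```

Declaring `forward` with `@abstractmethod` makes Python refuse to instantiate the base class directly, so every concrete wrapper, such as `SANWrapper`, has to supply its own forward pass.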
- residual_connection(model_out, batch)
Residual connection for the model.
For the cells of every rank, this method adds the input embeddings to the model output and applies layer normalization to the sum.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
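The residual step can be illustrated on plain lists. This sketch assumes that both the model output and the batch store rank-i embeddings under a key named `x_i` (a hypothetical convention), and it uses a layer normalization without the learnable scale and shift that `torch.nn.LayerNorm` applies by default.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def layer_norm(row: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize one feature vector to zero mean and unit variance (no learned affine)."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance, as LayerNorm uses
    return [(v - mean) / math.sqrt(var + eps) for v in row]


def residual_connection(
    model_out: Dict[str, Matrix], batch: Dict[str, Matrix], num_ranks: int
) -> Dict[str, Matrix]:
    """For each rank i present in both dicts, set x_i to layer_norm(model_out[x_i] + batch[x_i]) row by row."""
    for i in range(num_ranks):
        key = f"x_{i}"  # hypothetical key naming: one entry per cell rank
        if key in model_out and key in batch:
            model_out[key] = [
                layer_norm([a + b for a, b in zip(out_row, in_row)])
                for out_row, in_row in zip(model_out[key], batch[key])
            ]
    return model_out


# Only rank 1 is present, so rank 0 is skipped.
out = residual_connection({"x_1": [[1.0, 2.0]]}, {"x_1": [[1.0, 4.0]]}, num_ranks=2)
print([round(v, 4) for v in out["x_1"][0]])  # [-1.0, 1.0]
```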
- class SANWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SAN model.
This wrapper defines the forward pass of the model. The SAN model returns embeddings only for cells of rank 1 (edges). The embedding of each rank-0 cell (node) is computed as the sum of the embeddings of the rank-1 cells incident to it.
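The rank-0 aggregation described above has a compact matrix form. The following is an illustrative formulation, not taken from the source, which assumes an unsigned node-to-edge incidence matrix B_1 (entry 1 when node v is a face of edge e, 0 otherwise); with an oriented, signed incidence matrix the same product would instead weight each edge by its orientation sign.

```latex
h^{(0)}_v = \sum_{e \,:\, v \in e} h^{(1)}_e
\quad\Longleftrightarrow\quad
H^{(0)} = B_1 H^{(1)},
\qquad
(B_1)_{ve} =
\begin{cases}
1 & \text{if } v \in e,\\
0 & \text{otherwise,}
\end{cases}
```

Here H^{(1)} (one row per edge) stacks the rank-1 embeddings returned by SAN, and H^{(0)} (one row per node) holds the resulting rank-0 embeddings.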
- forward(batch)
Forward pass for the SAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
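A plain-Python sketch of this forward pass for a single complex follows. The SAN backbone is replaced by a stand-in callable (its real inputs, such as Laplacian operators, are omitted), and the batch keys `x_1`, `edges`, and `num_nodes`, as well as the helper `double`, are hypothetical; only the edge-to-node summation reflects the behaviour described above.

```python
from typing import Any, Callable, Dict, List

Matrix = List[List[float]]


def san_wrapper_forward(
    batch: Dict[str, Any], backbone: Callable[[Matrix], Matrix]
) -> Dict[str, Matrix]:
    # Rank-1 (edge) embeddings come straight from the backbone.
    x_1 = backbone(batch["x_1"])
    dim = len(x_1[0]) if x_1 else 0
    # Start every rank-0 (node) embedding at zero.
    x_0 = [[0.0] * dim for _ in range(batch["num_nodes"])]
    # Add each edge embedding to both of its endpoint nodes.
    for (u, v), h in zip(batch["edges"], x_1):
        for node in (u, v):
            for k in range(dim):
                x_0[node][k] += h[k]
    return {"x_0": x_0, "x_1": x_1}


def double(rows: Matrix) -> Matrix:
    """Hypothetical stand-in for the SAN backbone: scales every feature by 2."""
    return [[2.0 * value for value in row] for row in rows]


# Path 0 - 1 - 2 with two edges.
out = san_wrapper_forward(
    {"x_1": [[1.0, 0.0], [0.0, 2.0]], "edges": [(0, 1), (1, 2)], "num_nodes": 3},
    backbone=double,
)
print(out["x_0"])  # [[2.0, 0.0], [2.0, 4.0], [0.0, 4.0]]
```

Because every edge has exactly two endpoints, summing the node embeddings column by column gives twice the column sums of the edge embeddings, which is a quick sanity check on the aggregation.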