topobench.nn.wrappers.cell.cwn_wrapper module
Wrapper for the CWN model.
- class AbstractWrapper(backbone, **kwargs)
Abstract class that provides an interface to handle the network output.
- Parameters:
- backbone : torch.nn.Module
Backbone model.
- **kwargs : dict
Additional arguments for the class. It should contain the following keys:
  - out_channels (int): Number of output channels.
  - num_cell_dimensions (int): Number of cell dimensions.
- __init__(backbone, **kwargs)
- abstract forward(batch)
Forward pass for the model.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
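The interface above can be sketched in plain Python with the standard `abc` module. This is a hypothetical, dependency-free illustration rather than the TopoBench source: `AbstractWrapperSketch` and `IdentityWrapper` are made-up names, the batch is a plain dict, and the `x_0` key for rank-0 cell features is an assumption. It shows that the two documented keyword arguments are read at construction time and that `forward` must be overridden before a wrapper can be instantiated.

```python
from abc import ABC, abstractmethod


class AbstractWrapperSketch(ABC):
    """Hypothetical stand-in for AbstractWrapper; not the library code."""

    def __init__(self, backbone, **kwargs):
        self.backbone = backbone
        # The two documented kwargs keys; a KeyError means one is missing.
        self.out_channels = kwargs["out_channels"]
        self.dimensions = kwargs["num_cell_dimensions"]

    @abstractmethod
    def forward(self, batch):
        """Subclasses map a batch to a dict of model outputs."""


class IdentityWrapper(AbstractWrapperSketch):
    # Minimal concrete subclass: runs the backbone on the (assumed)
    # rank-0 features and returns the result under the same key.
    def forward(self, batch):
        return {"x_0": self.backbone(batch["x_0"])}
```

Because `forward` is abstract, only subclasses that implement it, such as `IdentityWrapper`, can be constructed.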
- residual_connection(model_out, batch)
Residual connection for the model.
For the cells of each rank, this method adds the input embeddings to the corresponding model output and applies layer normalization to the sum.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
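A minimal sketch of the residual connection, written in plain Python so it runs without PyTorch. Assumptions: each rank's embeddings are stored under a key `x_{rank}` in both the output dict and the batch, each embedding is a list of per-cell feature rows, and `layer_norm` is a simplified stand-in for `torch.nn.LayerNorm` without the learnable scale and shift. None of these names are taken from the TopoBench API.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance.

    Simplified stand-in for torch.nn.LayerNorm (no learnable affine terms).
    """
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def residual_connection(model_out, batch, num_cell_dimensions):
    """Hypothetical sketch: per rank, add input embeddings to the output, then normalize."""
    for rank in range(num_cell_dimensions):
        key = f"x_{rank}"  # assumed key naming, e.g. "x_0" for rank-0 cells
        if key in model_out and key in batch:
            # Sum output and input row by row (one row per cell),
            # then layer-normalize each summed row over its features.
            model_out[key] = [
                layer_norm([o + i for o, i in zip(out_row, in_row)])
                for out_row, in_row in zip(model_out[key], batch[key])
            ]
    return model_out
```

For example, a rank-0 output row `[1.0, 3.0]` plus an input row `[1.0, 1.0]` gives `[2.0, 4.0]`, which normalizes to approximately `[-1.0, 1.0]`. Ranks missing from either dict are skipped, and unrelated entries such as labels are left untouched.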
- class CWNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the CWN model.
This wrapper defines the forward pass of the model. The CWN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the CWN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
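The forward pass described above can be sketched as follows. This is a hypothetical illustration, not the TopoBench implementation: the backbone signature `backbone(x_0, x_1, x_2)`, the dict-based batch, and the output keys `x_0`, `x_1`, `x_2` and `labels` are assumptions. A real CWN backbone also needs the neighborhood structure of the cell complex (for example incidence matrices), which this sketch omits. `toy_backbone` is a stand-in that simply doubles every feature.

```python
class CWNWrapperSketch:
    """Hypothetical sketch of a CWN wrapper's forward pass (not the library code)."""

    def __init__(self, backbone):
        # Assumed backbone signature: backbone(x_0, x_1, x_2) -> (h_0, h_1, h_2).
        self.backbone = backbone

    def forward(self, batch):
        h_0, h_1, h_2 = self.backbone(batch["x_0"], batch["x_1"], batch["x_2"])
        # Collect the rank-0, rank-1 and rank-2 cell embeddings in one dict,
        # next to the labels, so downstream readouts can use them.
        return {"x_0": h_0, "x_1": h_1, "x_2": h_2, "labels": batch["y"]}


def toy_backbone(x_0, x_1, x_2):
    # Stand-in for a CWN backbone: doubles every feature of every cell.
    def double(cells):
        return [[2 * v for v in row] for row in cells]

    return double(x_0), double(x_1), double(x_2)
```

Calling `CWNWrapperSketch(toy_backbone).forward(batch)` on a batch with rank-0, rank-1 and rank-2 features returns a single dict holding all three sets of cell embeddings.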