topobench.nn.wrappers.simplicial.sccn_wrapper module

Wrapper for the SCCN model.

class AbstractWrapper(backbone, **kwargs)

Bases: ABC, Module

Abstract class that provides an interface to handle the network output.

Parameters:
backbone : torch.nn.Module

Backbone model.

**kwargs : dict

Additional arguments for the class. It should contain the following keys:
- out_channels (int): Number of output channels.
- num_cell_dimensions (int): Number of cell dimensions.

__init__(backbone, **kwargs)

abstract forward(batch)

Forward pass for the model.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.
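The interface documented above can be sketched as follows. This is a minimal illustration under stated assumptions, not TopoBench's implementation: the class name `WrapperInterfaceSketch`, the toy subclass `EchoWrapper`, the per-rank `ln_{i}` LayerNorm attributes, and the use of a plain dict keyed `x_{i}` in place of a `torch_geometric.data.Data` batch are all hypothetical. It shows the two documented constructor keys being consumed and `forward` being enforced as abstract.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class WrapperInterfaceSketch(ABC, nn.Module):
    """Hypothetical sketch of the documented interface (not TopoBench's code)."""

    def __init__(self, backbone: nn.Module, **kwargs):
        super().__init__()
        self.backbone = backbone
        # The two required keys listed in the docstring.
        out_channels = kwargs["out_channels"]
        self.dimensions = range(kwargs["num_cell_dimensions"])
        # Assumption: one LayerNorm per cell rank, stored as ln_0, ln_1, ...
        for i in self.dimensions:
            setattr(self, f"ln_{i}", nn.LayerNorm(out_channels))

    @abstractmethod
    def forward(self, batch):
        """Subclasses must map a batch to a dict of per-rank outputs."""


class EchoWrapper(WrapperInterfaceSketch):
    """Hypothetical toy subclass: returns each rank's input features unchanged."""

    def forward(self, batch):
        return {f"x_{i}": batch[f"x_{i}"] for i in self.dimensions}


# Usage: three cell ranks with 8 channels each.
wrapper = EchoWrapper(nn.Identity(), out_channels=8, num_cell_dimensions=3)
batch = {f"x_{i}": torch.randn(4, 8) for i in range(3)}
out = wrapper(batch)
```

Because `forward` is marked abstract, instantiating `WrapperInterfaceSketch` directly raises `TypeError`; only subclasses that implement it can be built.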

residual_connection(model_out, batch)

Residual connection for the model.

For the cell embeddings of every rank, this method adds the input embeddings to the model output and applies layer normalization to the sum.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
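The residual step described above can be illustrated with a standalone function. It is a sketch under assumptions: the name `residual_connection_sketch`, the `nn.ModuleDict` of norms keyed `x_{i}`, and plain dicts standing in for `model_out` and the batch are hypothetical, and the documented method may select ranks differently. Here, a rank is updated only if it has a norm and appears in both dicts; other ranks are left unchanged.

```python
import torch
from torch import nn


def residual_connection_sketch(model_out: dict, batch: dict, norms: nn.ModuleDict) -> dict:
    """Hypothetical residual + LayerNorm step over per-rank embeddings."""
    # For each rank key x_{i} with a registered norm that appears in both
    # dicts, replace the model output with LayerNorm(model output + input).
    for key, norm in norms.items():
        if key in model_out and key in batch:
            model_out[key] = norm(model_out[key] + batch[key])
    return model_out


# Usage: three ranks with 8 channels; rank 2 is absent from the batch.
torch.manual_seed(0)
channels = 8
norms = nn.ModuleDict({f"x_{i}": nn.LayerNorm(channels) for i in range(3)})
batch = {f"x_{i}": torch.randn(5, channels) for i in range(2)}
model_out = {f"x_{i}": torch.randn(5, channels) for i in range(3)}
before = {k: v.clone() for k, v in model_out.items()}  # model_out is updated in place
out = residual_connection_sketch(model_out, batch, norms)
```

A freshly initialized `nn.LayerNorm` has weight 1 and bias 0, so each normalized row of `out["x_0"]` has mean close to 0.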

class SCCNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the SCCN model.

This wrapper defines the forward pass of the model. The SCCN model returns embeddings for the cells of every rank.

forward(batch)

Forward pass for the SCCN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
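One possible shape for the SCCN forward pass is sketched below. Everything here is hypothetical: `ToyRankBackbone` stands in for the actual SCCN backbone (a real SCCN layer also uses incidence matrices between ranks, which are omitted), the `x_{r}`, `adjacency_{r}`, `rank_{r}` and `y` key names are assumptions, and a plain dict replaces the `torch_geometric.data.Data` batch. For brevity, the residual connection is folded into `forward`, whereas the documented classes keep it in a separate `residual_connection` method. The sketch shows only the plumbing: translating batch keys into per-rank inputs, calling the backbone, and re-keying the embeddings as `x_{r}`.

```python
import torch
from torch import nn


class ToyRankBackbone(nn.Module):
    """Hypothetical stand-in for an SCCN backbone: a linear map per rank,
    then aggregation over a dense same-rank adjacency matrix."""

    def __init__(self, channels: int, max_rank: int):
        super().__init__()
        self.max_rank = max_rank
        self.lins = nn.ModuleList([nn.Linear(channels, channels) for _ in range(max_rank + 1)])

    def forward(self, features: dict, adjacencies: dict) -> dict:
        return {
            f"rank_{r}": adjacencies[f"rank_{r}"] @ self.lins[r](features[f"rank_{r}"])
            for r in range(self.max_rank + 1)
        }


class SCCNWrapperSketch(nn.Module):
    """Hypothetical sketch of the wrapper plumbing (not TopoBench's code)."""

    def __init__(self, backbone: ToyRankBackbone, out_channels: int):
        super().__init__()
        self.backbone = backbone
        self.norms = nn.ModuleList([nn.LayerNorm(out_channels) for _ in range(backbone.max_rank + 1)])

    def forward(self, batch: dict) -> dict:
        ranks = range(self.backbone.max_rank + 1)
        # Translate the batch's x_{r} / adjacency_{r} keys into per-rank dicts.
        features = {f"rank_{r}": batch[f"x_{r}"] for r in ranks}
        adjacencies = {f"rank_{r}": batch[f"adjacency_{r}"] for r in ranks}
        output = self.backbone(features, adjacencies)
        # Re-key embeddings as x_{r}; add the input (residual) and normalize.
        model_out = {"labels": batch["y"]}
        for r in ranks:
            model_out[f"x_{r}"] = self.norms[r](output[f"rank_{r}"] + batch[f"x_{r}"])
        return model_out


# Usage: arbitrary cell counts for ranks 0, 1, 2 (nodes, edges, triangles).
torch.manual_seed(0)
channels, max_rank = 4, 2
batch = {"y": torch.tensor([1])}
for r, n in enumerate([6, 9, 4]):
    batch[f"x_{r}"] = torch.randn(n, channels)
    batch[f"adjacency_{r}"] = torch.eye(n)  # identity: each cell only sees itself
wrapper = SCCNWrapperSketch(ToyRankBackbone(channels, max_rank), channels)
out = wrapper(batch)
```

Keeping the key translation in the wrapper lets the backbone stay agnostic of the batch format, which matches the role the docstring assigns to wrappers.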