topobench.nn.wrappers.simplicial package
Wrappers for simplicial neural networks with automated exports.
- class SANWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SAN model.
This wrapper defines the forward pass of the model. The SAN model returns embeddings only for the cells of rank 1 (edges). The wrapper then computes the embedding of each cell of rank 0 (node) as the sum of the embeddings of the rank-1 cells incident to it.
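The rank-1 to rank-0 aggregation described above can be sketched as a product with the absolute value of the node-to-edge incidence matrix. This is a minimal sketch assuming a dense, signed incidence matrix `B1` of shape `[num_nodes, num_edges]`; the names `B1`, `x_1`, and `edges_to_nodes` are hypothetical, and the wrapper itself may read the incidence from the batch in another form (for a sparse incidence, `torch.sparse.mm` computes the same product).

```python
import torch

# Toy simplicial complex: 3 nodes (rank 0) and 2 edges (rank 1): e0 = (0, 1), e1 = (1, 2).
# Hypothetical signed node-to-edge incidence matrix, shape [num_nodes, num_edges].
B1 = torch.tensor([
    [-1.0,  0.0],
    [ 1.0, -1.0],
    [ 0.0,  1.0],
])

# Rank-1 embeddings as a backbone like SAN might return them: [num_edges, hidden_dim].
x_1 = torch.tensor([
    [1.0, 2.0],
    [10.0, 20.0],
])


def edges_to_nodes(incidence_1: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
    """Sum the embeddings of all edges incident to each node.

    Taking the absolute value drops the orientation signs, so every
    incident edge contributes with weight +1.
    """
    return incidence_1.abs() @ x_1


x_0 = edges_to_nodes(B1, x_1)
# Node 1 touches both edges, so its row is the sum of both edge embeddings.
print(x_0)
```

Using the unsigned incidence makes the node embedding independent of the arbitrary edge orientations chosen when building the complex.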
- forward(batch)
Forward pass for the SAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class SCCNNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCNN model.
This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the SCCNN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
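The general pattern of a wrapper whose backbone returns embeddings for ranks 0, 1, and 2 (run the backbone, then collect the per-rank outputs in a dictionary) can be sketched with toy modules. Everything here is hypothetical: `ToyRankBackbone`, `ToyThreeRankWrapper`, the batch attribute names `x_0`, `x_1`, `x_2`, `y`, and the output keys are illustrative stand-ins, not TopoBench's API, and the toy backbone ignores the Laplacians and incidence matrices a real SCCNN consumes.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyRankBackbone(nn.Module):
    """Hypothetical stand-in for a backbone such as SCCNN: one linear map per rank.

    A real simplicial backbone would also consume Laplacians and incidence
    matrices; this toy ignores topology to keep the example small.
    """

    def __init__(self, dims):
        super().__init__()
        self.lins = nn.ModuleList([nn.Linear(d, d) for d in dims])

    def forward(self, x_0, x_1, x_2):
        return tuple(lin(x) for lin, x in zip(self.lins, (x_0, x_1, x_2)))


class ToyThreeRankWrapper(nn.Module):
    """Hypothetical wrapper: run the backbone and pack per-rank outputs into a dict."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, batch):
        x_0, x_1, x_2 = self.backbone(batch.x_0, batch.x_1, batch.x_2)
        # Key names are illustrative; the real wrappers may use different keys.
        return {"x_0": x_0, "x_1": x_1, "x_2": x_2, "labels": batch.y}


# SimpleNamespace stands in for a torch_geometric.data.Data batch.
batch = SimpleNamespace(
    x_0=torch.randn(4, 8),  # 4 nodes
    x_1=torch.randn(5, 8),  # 5 edges
    x_2=torch.randn(2, 8),  # 2 triangles
    y=torch.tensor([1]),
)
wrapper = ToyThreeRankWrapper(ToyRankBackbone([8, 8, 8]))
out = wrapper(batch)
print({key: tuple(value.shape) for key, value in out.items()})
```

Keeping the backbone free of output bookkeeping lets the same dictionary format feed downstream readout and loss modules regardless of which backbone produced it.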
- class SCCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCN model.
This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.
- forward(batch)
Forward pass for the SCCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
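Because SCCN can return embeddings for any number of ranks, its wrapper cannot hard-code a fixed set of outputs. A minimal sketch, assuming (hypothetically) that the backbone returns a dict keyed `"rank_0"`, `"rank_1"`, and so on; the function name `pack_rank_outputs` and both key conventions are illustrative, not TopoBench's API.

```python
import torch


def pack_rank_outputs(features_by_rank: dict) -> dict:
    """Copy per-rank embeddings into a flat output dict, for however many ranks exist.

    Assumes (hypothetically) that the backbone returns a dict keyed by
    "rank_<r>"; each entry becomes "x_<r>" in the output.
    """
    model_out = {}
    for key, x in features_by_rank.items():
        rank = int(key.split("_")[1])
        model_out[f"x_{rank}"] = x
    return model_out


# Toy backbone output for a complex with cells up to rank 3.
features = {f"rank_{r}": torch.randn(n, 16) for r, n in enumerate([6, 9, 4, 1])}
out = pack_rank_outputs(features)
print(sorted(out))
```

Iterating over whatever ranks the backbone produced keeps the wrapper independent of the maximum rank of the input complex.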
- class SCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCNW model.
This wrapper defines the forward pass of the model. The SCNW model returns the embeddings of the cells of rank 0, 1, and 2.
- forward(batch)
Forward pass for the SCNW wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- normalize_matrix(matrix)
Normalize the input matrix.
The normalization uses the diagonal matrix D^{-1/2}, where D_ii is the sum of the absolute values of the entries in row i of the input.
- Parameters:
- matrix : torch.sparse.FloatTensor
Input matrix to be normalized.
- Returns:
- torch.sparse.FloatTensor
Normalized matrix.
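A minimal sketch of this normalization, reading the description as the symmetric form D^{-1/2} M D^{-1/2} with D_ii = Σ_j |M_ij|. Applying the diagonal on both sides, and leaving all-zero rows at zero instead of dividing by zero, are assumptions of this sketch rather than confirmed behavior of `normalize_matrix`; it also densifies the input for clarity, which is only sensible for small matrices.

```python
import torch


def normalize_matrix(matrix: torch.Tensor) -> torch.Tensor:
    """Sketch: symmetrically normalize a sparse square matrix as D^{-1/2} M D^{-1/2}.

    D is diagonal with D_ii = sum_j |M_ij|. Rows whose absolute sum is zero
    get a zero scaling factor instead of a division by zero.
    """
    dense = matrix.to_dense()
    row_sums = dense.abs().sum(dim=1)
    inv_sqrt = torch.zeros_like(row_sums)
    nonzero = row_sums != 0
    inv_sqrt[nonzero] = row_sums[nonzero].rsqrt()
    # Scaling rows and columns by inv_sqrt equals multiplying by D^{-1/2} on both sides.
    normalized = inv_sqrt[:, None] * dense * inv_sqrt[None, :]
    return normalized.to_sparse()


# Graph Laplacian of the path 0 - 1 - 2; absolute row sums are 2, 4, 2.
laplacian = torch.tensor([
    [ 1.0, -1.0,  0.0],
    [-1.0,  2.0, -1.0],
    [ 0.0, -1.0,  1.0],
]).to_sparse()
print(normalize_matrix(laplacian).to_dense())
```

The symmetric form preserves the symmetry of Hodge Laplacians and bounds the magnitude of their entries, which helps keep repeated message passing numerically stable.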