topobench.nn.wrappers.simplicial package

Submodules

topobench.nn.wrappers.simplicial.san_wrapper module

Wrapper for the SAN model.

class topobench.nn.wrappers.simplicial.san_wrapper.SANWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the SAN model.

This wrapper defines the forward pass of the model. The SAN model returns embeddings only for the cells of rank 1 (edges). The wrapper computes the embeddings of the cells of rank 0 (nodes) as the sum of the embeddings of the rank-1 cells incident to each node.
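The edge-to-node aggregation described above can be expressed as one sparse matrix product. The following is a minimal sketch, not the wrapper's source: it assumes an unsigned incidence matrix of shape [num_nodes, num_edges] with a 1 wherever a node is an endpoint of an edge, and the helper name `edge_to_node_embeddings` is hypothetical.

```python
import torch


def edge_to_node_embeddings(incidence_1: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum the rank-1 embeddings incident to each rank-0 cell.

    incidence_1: sparse COO tensor [num_nodes, num_edges]; entry 1 where the
        node is an endpoint of the edge (unsigned incidence is an assumption).
    x_1: dense tensor [num_edges, hidden_dim] of edge embeddings.
    """
    # Row i of B1 @ X1 adds up the rows of X1 for every edge touching node i
    return torch.sparse.mm(incidence_1, x_1)


# Toy path graph 0 - 1 - 2 with edges e0 = (0, 1) and e1 = (1, 2)
B1 = torch.tensor([[1.0, 0.0],
                   [1.0, 1.0],
                   [0.0, 1.0]]).to_sparse()
x_1 = torch.tensor([[1.0, 10.0],
                    [2.0, 20.0]])
x_0 = edge_to_node_embeddings(B1, x_1)
# x_0 == [[1, 10], [3, 30], [2, 20]]: node 1 touches both edges, so it gets their sum
```

With a signed (oriented) incidence matrix the same product would yield a signed combination rather than a plain sum, so the unsigned form is what matches the description.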

forward(batch)

Forward pass for the SAN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

topobench.nn.wrappers.simplicial.sccn_wrapper module

Wrapper for the SCCN model.

class topobench.nn.wrappers.simplicial.sccn_wrapper.SCCNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the SCCN model.

This wrapper defines the forward pass of the model. The SCCN model returns embeddings for the cells of every rank in the complex.
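Because the SCCN model produces an embedding tensor for every rank, a wrapper of this kind can loop over ranks instead of hard-coding them. The sketch below is illustrative only: the `rank_{r}` input keys, the `x_{r}` output keys, and the `collect_rank_embeddings` helper are hypothetical and are not taken from the wrapper's source.

```python
import torch


def collect_rank_embeddings(backbone_out: dict, max_rank: int) -> dict:
    """Hypothetical helper: flatten per-rank backbone outputs into one dict.

    Assumes the backbone returns a dict keyed "rank_0", "rank_1", ...
    and exposes each tensor under "x_0", "x_1", ... in the model output.
    """
    model_out = {}
    for r in range(max_rank + 1):
        # One embedding tensor per cell rank, whatever the complex dimension
        model_out[f"x_{r}"] = backbone_out[f"rank_{r}"]
    return model_out


# A single filled triangle: 3 nodes, 3 edges, 1 face, hidden_dim = 4
backbone_out = {f"rank_{r}": torch.randn(n, 4) for r, n in enumerate([3, 3, 1])}
model_out = collect_rank_embeddings(backbone_out, max_rank=2)
print({k: tuple(v.shape) for k, v in model_out.items()})
# {'x_0': (3, 4), 'x_1': (3, 4), 'x_2': (1, 4)}
```

Raising `max_rank` is the only change needed for higher-dimensional complexes, which is the practical meaning of "every rank" here.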

forward(batch)

Forward pass for the SCCN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

topobench.nn.wrappers.simplicial.sccnn_wrapper module

Wrapper for the SCCNN model.

class topobench.nn.wrappers.simplicial.sccnn_wrapper.SCCNNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the SCCNN model.

This wrapper defines the forward pass of the model. The SCCNN model returns embeddings for the cells of ranks 0, 1, and 2.
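Unlike the any-rank case, a model with a fixed output of ranks 0, 1, and 2 yields exactly three tensors, each with one row per cell of its rank. The sketch below packs such a triple and checks the row counts; the `pack_ranks_0_to_2` helper, its `num_cells` argument, and the `x_0`/`x_1`/`x_2` key names are hypothetical illustrations, not the wrapper's API.

```python
import torch


def pack_ranks_0_to_2(x_0, x_1, x_2, num_cells):
    """Hypothetical helper: pack rank-0/1/2 embeddings into one dict.

    num_cells is (n_nodes, n_edges, n_triangles); each embedding tensor is
    expected to have exactly one row per cell of its rank.
    """
    out = {"x_0": x_0, "x_1": x_1, "x_2": x_2}
    for (key, x), n in zip(out.items(), num_cells):
        if x.shape[0] != n:
            raise ValueError(f"{key}: got {x.shape[0]} rows, expected {n}")
    return out


# One filled triangle has 3 nodes, 3 edges and 1 triangle (hidden_dim = 8)
out = pack_ranks_0_to_2(torch.zeros(3, 8), torch.zeros(3, 8), torch.zeros(1, 8),
                        num_cells=(3, 3, 1))
```

The row-count check is a cheap way to catch a rank mix-up (for example, swapping edge and triangle embeddings) before the outputs reach a readout.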

forward(batch)

Forward pass for the SCCNN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

topobench.nn.wrappers.simplicial.scn_wrapper module

Wrapper for the SCN model.

class topobench.nn.wrappers.simplicial.scn_wrapper.SCNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the SCN model.

This wrapper defines the forward pass of the model. The SCN model returns embeddings for the cells of ranks 0, 1, and 2.

forward(batch)

Forward pass for the SCN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

normalize_matrix(matrix)

Normalize the input matrix.

The normalization uses the diagonal matrix D^(-1/2), where D_ii is the sum of the absolute values of row i of the input matrix.

Parameters:
matrix : torch.sparse.FloatTensor

Input matrix to be normalized.

Returns:
torch.sparse.FloatTensor

Normalized matrix.
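A minimal dense sketch of this normalization follows. It assumes the symmetric form D^(-1/2) A D^(-1/2) with D_ii = sum_j |A_ij| (the description does not state whether D^(-1/2) is applied on one side or both), and it assumes that all-zero rows are kept at zero instead of dividing by zero. This standalone function is illustrative, not the method's source.

```python
import torch


def normalize_matrix(matrix: torch.Tensor) -> torch.Tensor:
    """Sketch: symmetric normalization D^(-1/2) A D^(-1/2) of a sparse matrix.

    D_ii is the sum of absolute values of row i. The symmetric form and the
    zero-row guard are assumptions, not confirmed by the original method.
    """
    # Work densely for clarity; adequate for small matrices
    dense = matrix.to_dense()
    # Row sums of absolute values, so signed (oriented) entries do not cancel
    row_sums = dense.abs().sum(dim=1)
    # Assumed guard: empty rows get D_ii = 1, leaving them all zeros
    row_sums[row_sums == 0] = 1.0
    d_inv_sqrt = torch.diag(row_sums.pow(-0.5))
    # Scale rows and columns, then return in sparse COO layout like the input
    return (d_inv_sqrt @ dense @ d_inv_sqrt).to_sparse()


A = torch.tensor([[1.0, -1.0],
                  [-1.0, 1.0]]).to_sparse()
print(normalize_matrix(A).to_dense())
# approximately [[0.5, -0.5], [-0.5, 0.5]]: both row sums are 2, so each entry is divided by 2
```

Taking absolute values before summing matters for signed operators such as Hodge Laplacians, whose off-diagonal entries can be negative and would otherwise cancel in the row sum.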

Module contents

Wrappers for simplicial neural networks with automated exports.

class topobench.nn.wrappers.simplicial.SANWrapper(backbone, **kwargs)

Alias of topobench.nn.wrappers.simplicial.san_wrapper.SANWrapper.

class topobench.nn.wrappers.simplicial.SCCNNWrapper(backbone, **kwargs)

Alias of topobench.nn.wrappers.simplicial.sccnn_wrapper.SCCNNWrapper.

class topobench.nn.wrappers.simplicial.SCCNWrapper(backbone, **kwargs)

Alias of topobench.nn.wrappers.simplicial.sccn_wrapper.SCCNWrapper.

class topobench.nn.wrappers.simplicial.SCNWrapper(backbone, **kwargs)

Alias of topobench.nn.wrappers.simplicial.scn_wrapper.SCNWrapper.