topobench.nn.wrappers.simplicial package
Submodules
topobench.nn.wrappers.simplicial.san_wrapper module
Wrapper for the SAN model.
- class topobench.nn.wrappers.simplicial.san_wrapper.SANWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SAN model.
This wrapper defines the forward pass of the model. The SAN model returns embeddings only for the cells of rank 1. The embedding of each cell of rank 0 is computed as the sum of the embeddings of the cells of rank 1 connected to it.
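The rank-1 to rank-0 aggregation described above can be illustrated with a small, dependency-free sketch. It is not the wrapper's code: the function name `aggregate_edges_to_nodes`, the edge-list input, and the use of plain Python lists are all assumptions made for illustration. The wrapper itself operates on the tensors carried by the batch.

```python
def aggregate_edges_to_nodes(num_nodes, edges, edge_embeddings):
    """Sum the embedding of every rank-1 cell into each rank-0 cell it touches.

    Hypothetical helper: `edges` is a list of (u, v) node-index pairs and
    `edge_embeddings[i]` is the embedding of `edges[i]`.
    """
    dim = len(edge_embeddings[0]) if edge_embeddings else 0
    node_embeddings = [[0.0] * dim for _ in range(num_nodes)]
    for (u, v), embedding in zip(edges, edge_embeddings):
        # Each edge contributes its full embedding to both of its endpoints.
        for node in (u, v):
            for k in range(dim):
                node_embeddings[node][k] += embedding[k]
    return node_embeddings


# A triangle with three edges and 2-dimensional edge embeddings.
edges = [(0, 1), (1, 2), (0, 2)]
x_1 = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
x_0 = aggregate_edges_to_nodes(3, edges, x_1)
```

In matrix form this is the product of an unsigned node-to-edge incidence matrix with the rank-1 embeddings. Whether the incidence matrix used in practice is signed is not stated here, so the unsigned sum is an assumption of this sketch.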
topobench.nn.wrappers.simplicial.sccn_wrapper module
Wrapper for the SCCN model.
- class topobench.nn.wrappers.simplicial.sccn_wrapper.SCCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCN model.
This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.
topobench.nn.wrappers.simplicial.sccnn_wrapper module
Wrapper for the SCCNN model.
- class topobench.nn.wrappers.simplicial.sccnn_wrapper.SCCNNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCNN model.
This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of ranks 0, 1, and 2.
topobench.nn.wrappers.simplicial.scn_wrapper module
Wrapper for the SCN model.
- class topobench.nn.wrappers.simplicial.scn_wrapper.SCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCN model.
This wrapper defines the forward pass of the model. The SCN model returns the embeddings of the cells of ranks 0, 1, and 2.
- forward(batch)
Forward pass for the SCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- normalize_matrix(matrix)
Normalize the input matrix.
The matrix is normalized with the diagonal matrix whose entries are the inverse square roots of the sums of the absolute values of each row.
- Parameters:
- matrix : torch.sparse.FloatTensor
Input matrix to be normalized.
- Returns:
- torch.sparse.FloatTensor
Normalized matrix.
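To make the normalization rule concrete, the following is a minimal dense sketch in plain Python, not the method's implementation. It assumes a square input and that the diagonal matrix is applied on both sides (D^{-1/2} M D^{-1/2}). It also assumes that rows whose absolute sum is zero are left as zero. None of these points is stated in the description above, and the name `normalize_dense` is hypothetical.

```python
import math


def normalize_dense(matrix):
    """Return D^{-1/2} M D^{-1/2}, with D[i][i] = sum_j |M[i][j]|.

    Hypothetical dense stand-in for the sparse method; `matrix` is a
    square list of lists of floats.
    """
    n = len(matrix)
    # Inverse square root of each absolute row sum.
    # Assumption: an all-zero row gets a scale of 0 to avoid dividing by zero.
    scale = []
    for row in matrix:
        total = sum(abs(value) for value in row)
        scale.append(1.0 / math.sqrt(total) if total != 0 else 0.0)
    # Multiplying by a diagonal matrix on both sides scales entry (i, j)
    # by scale[i] * scale[j].
    return [
        [scale[i] * matrix[i][j] * scale[j] for j in range(n)]
        for i in range(n)
    ]


# A small Laplacian-like matrix with absolute row sums 4, 2, and 2.
laplacian = [
    [2.0, -1.0, -1.0],
    [-1.0, 1.0, 0.0],
    [-1.0, 0.0, 1.0],
]
normalized = normalize_dense(laplacian)
```

With these inputs each diagonal entry becomes 0.5, because entry (i, i) is divided by the absolute sum of row i. The sparse method takes and returns `torch.sparse.FloatTensor` objects rather than nested lists.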
Module contents
Wrappers for simplicial neural networks with automated exports.
- class topobench.nn.wrappers.simplicial.SANWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SAN model.
This wrapper defines the forward pass of the model. The SAN model returns embeddings only for the cells of rank 1. The embedding of each cell of rank 0 is computed as the sum of the embeddings of the cells of rank 1 connected to it.
- forward(batch)
Forward pass for the SAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.simplicial.SCCNNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCNN model.
This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of ranks 0, 1, and 2.
- forward(batch)
Forward pass for the SCCNN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.simplicial.SCCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCCN model.
This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.
- forward(batch)
Forward pass for the SCCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class topobench.nn.wrappers.simplicial.SCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the SCN model.
This wrapper defines the forward pass of the model. The SCN model returns the embeddings of the cells of ranks 0, 1, and 2.
- forward(batch)
Forward pass for the SCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- normalize_matrix(matrix)
Normalize the input matrix.
The matrix is normalized with the diagonal matrix whose entries are the inverse square roots of the sums of the absolute values of each row.
- Parameters:
- matrix : torch.sparse.FloatTensor
Input matrix to be normalized.
- Returns:
- torch.sparse.FloatTensor
Normalized matrix.