topobench.nn.wrappers.cell package

Wrappers for cell neural networks. The wrapper classes in this package are exported automatically.

class CANWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the CAN model.

This wrapper defines the forward pass of the model. The CAN model returns embeddings only for the cells of rank 1 (edges). The wrapper then computes the embedding of each rank-0 cell (node) as the sum of the embeddings of the rank-1 cells incident to it.
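The rank-1 to rank-0 summation described above can be sketched in plain Python. This is an illustration under stated assumptions, not TopoBench's implementation: `aggregate_rank1_to_rank0` is a hypothetical helper, rank-1 cells are taken to be edges given as `(u, v)` node-index pairs, and embeddings are plain lists of floats rather than tensors. A tensor implementation would typically express the same step as a sparse matrix product with the node-edge incidence matrix.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def aggregate_rank1_to_rank0(
    num_nodes: int,
    edges: Sequence[Tuple[int, int]],
    edge_embeddings: Sequence[Vector],
) -> List[Vector]:
    """Return one embedding per node: the sum of the embeddings of the
    edges (rank-1 cells) incident to that node.

    Hypothetical helper for illustration; edges are (u, v) node pairs.
    """
    if len(edges) != len(edge_embeddings):
        raise ValueError("expected exactly one embedding per edge")
    dim = len(edge_embeddings[0]) if edge_embeddings else 0
    # Start every node at the zero vector (the empty sum).
    node_embeddings = [[0.0] * dim for _ in range(num_nodes)]
    for (u, v), embedding in zip(edges, edge_embeddings):
        # Each edge contributes its full embedding to both endpoints.
        for node in (u, v):
            row = node_embeddings[node]
            for k, value in enumerate(embedding):
                row[k] += value
    return node_embeddings


if __name__ == "__main__":
    # Path graph 0 - 1 - 2 with two edges and 2-dimensional embeddings.
    result = aggregate_rank1_to_rank0(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        edge_embeddings=[[1.0, 2.0], [10.0, 20.0]],
    )
    print(result)  # [[1.0, 2.0], [11.0, 22.0], [10.0, 20.0]]
```

Nodes not touched by any edge keep a zero embedding, which matches a sum over an empty set.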

forward(batch)

Forward pass for the CAN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

class CCCNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the CCCN model.

This wrapper defines the forward pass of the model. The CCCN model returns embeddings only for the cells of rank 1 (edges). The wrapper then computes the embedding of each rank-0 cell (node) as the sum of the embeddings of the rank-1 cells incident to it.
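The rank-0 summation used by this wrapper (and equally by CAN) can be restated as a matrix product. This is a mathematical restatement, not a description of TopoBench's code. It assumes an unsigned node-edge incidence matrix $B_1$ for a complex with $n_0$ nodes, $n_1$ edges and embedding width $d$, with the rank-1 embeddings stacked row-wise in $X_1 \in \mathbb{R}^{n_1 \times d}$. With an oriented (signed) incidence matrix, the same product would yield a signed combination rather than a plain sum.

```latex
(B_1)_{v,e} =
\begin{cases}
1 & \text{if node } v \text{ lies on the boundary of edge } e,\\
0 & \text{otherwise,}
\end{cases}
\qquad
\mathbf{x}_0(v) = \sum_{e \,:\, v \in \partial e} \mathbf{x}_1(e)
\quad\Longleftrightarrow\quad
X_0 = B_1 X_1 \in \mathbb{R}^{n_0 \times d}.
```

Row $v$ of $B_1 X_1$ is $\sum_e (B_1)_{v,e}\,\mathbf{x}_1(e)$, and only the edges incident to $v$ have a nonzero coefficient, so it equals the sum above.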

forward(batch)

Forward pass for the CCCN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.

class CCXNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the CCXN model.

This wrapper defines the forward pass of the model. The CCXN model returns the embeddings of the cells of rank 0, 1, and 2.

forward(batch)

Forward pass for the CCXN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the updated model output.
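The forward contract described for these wrappers (take a batch, call the backbone, return a dictionary) can be sketched with Python's `abc` module. Everything here is hypothetical. `SketchAbstractWrapper` only stands in for `AbstractWrapper`, whose real interface is not documented on this page. The backbone is assumed to take the whole batch and return a tuple of rank-0, rank-1 and rank-2 embeddings. The output keys `x_0`, `x_1`, `x_2` and `labels`, like the batch attribute `y`, are illustrative rather than TopoBench's actual schema.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any, Callable, Dict, List


class SketchAbstractWrapper(ABC):
    """Hypothetical stand-in for AbstractWrapper: stores the backbone and
    requires subclasses to implement forward(batch) -> dict."""

    def __init__(self, backbone: Callable[[Any], Any], **kwargs: Any) -> None:
        self.backbone = backbone
        self.options = kwargs  # kept for subclasses; unused in this sketch

    @abstractmethod
    def forward(self, batch: Any) -> Dict[str, Any]:
        """Run the backbone on a batch and return the model output."""

    def __call__(self, batch: Any) -> Dict[str, Any]:
        return self.forward(batch)


class SketchRank012Wrapper(SketchAbstractWrapper):
    """Hypothetical CCXN/CWN-style wrapper: the backbone already returns
    embeddings for cells of rank 0, 1 and 2, so no aggregation is needed."""

    def forward(self, batch: Any) -> Dict[str, Any]:
        x_0, x_1, x_2 = self.backbone(batch)
        # Illustrative key names, not TopoBench's documented output schema.
        return {"x_0": x_0, "x_1": x_1, "x_2": x_2, "labels": batch.y}


def toy_backbone(batch: Any):
    """Toy backbone that doubles every feature of every rank."""

    def double(rows: List[List[float]]) -> List[List[float]]:
        return [[2.0 * value for value in row] for row in rows]

    return double(batch.x_0), double(batch.x_1), double(batch.x_2)


batch = SimpleNamespace(x_0=[[1.0]], x_1=[[2.0]], x_2=[[3.0]], y=[0])
output = SketchRank012Wrapper(toy_backbone)(batch)
print(output["x_2"])  # [[6.0]]
```

Declaring `forward` as abstract means a wrapper subclass that forgets to implement it fails at construction time with a `TypeError`, rather than at the first batch.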

class CWNWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the CWN model.

This wrapper defines the forward pass of the model. The CWN model returns the embeddings of the cells of rank 0, 1, and 2.

forward(batch)

Forward pass for the CWN wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the updated model output.

Submodules