topobench.nn.wrappers.cell package
Wrappers for cell neural networks with automated exports.
- class CANWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the CAN model.
This wrapper defines the forward pass of the model. The CAN backbone returns embeddings for the cells of rank 1 (edges). The embeddings of the cells of rank 0 (nodes) are then computed by summing, for each node, the embeddings of the rank-1 cells incident to it.
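The rank-1 to rank-0 aggregation described above can be sketched in plain Python. This is a minimal illustration, not the TopoBench implementation: it assumes edges are given as pairs of node indices (`edges`), uses the hypothetical helper name `edges_to_nodes`, and represents embeddings as nested lists rather than tensors. Summing each edge embedding into both of its endpoints is equivalent to multiplying an unsigned node-to-edge incidence matrix by the edge-embedding matrix.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def edges_to_nodes(
    num_nodes: int,
    edges: Sequence[Edge],
    edge_emb: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Hypothetical helper: sum each rank-1 (edge) embedding into its two rank-0 endpoints.

    Equivalent to an unsigned (num_nodes x num_edges) incidence matrix
    multiplied by the (num_edges x dim) edge-embedding matrix.
    """
    if len(edges) != len(edge_emb):
        raise ValueError("expected exactly one embedding per edge")
    dim = len(edge_emb[0]) if edge_emb else 0
    # Nodes with no incident edge keep an all-zero embedding.
    node_emb = [[0.0] * dim for _ in range(num_nodes)]
    for (u, v), h in zip(edges, edge_emb):
        for node in (u, v):
            for k in range(dim):
                node_emb[node][k] += h[k]
    return node_emb


# Triangle 0-1-2 with 2-dimensional edge embeddings.
edges = [(0, 1), (1, 2), (0, 2)]
edge_emb = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(edges_to_nodes(3, edges, edge_emb))  # [[3.0, 2.0], [1.0, 1.0], [2.0, 3.0]]
```

Node 0 receives the embeddings of edges (0, 1) and (0, 2), node 1 those of (0, 1) and (1, 2), and so on. In a tensor library the same result would typically come from a single sparse matrix product.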
- forward(batch)
Forward pass for the CAN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class CCCNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the CCCN model.
This wrapper defines the forward pass of the model. The CCCN backbone returns embeddings for the cells of rank 1 (edges). The embeddings of the cells of rank 0 (nodes) are then computed by summing, for each node, the embeddings of the rank-1 cells incident to it.
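The rank-0 aggregation can also be stated as a matrix product. Assuming an unsigned node-to-edge incidence matrix $B_1 \in \{0,1\}^{n_0 \times n_1}$ (for $n_0$ nodes and $n_1$ edges) and rank-1 embeddings $X_1 \in \mathbb{R}^{n_1 \times d}$ returned by the backbone, the rank-0 embeddings $X_0 \in \mathbb{R}^{n_0 \times d}$ are, row by row:

```latex
X_0 = B_1 X_1,
\qquad
(X_0)_{v} = \sum_{e \,:\, v \in e} (X_1)_{e},
\qquad
(B_1)_{v e} =
\begin{cases}
1 & \text{if node } v \text{ is an endpoint of edge } e,\\
0 & \text{otherwise.}
\end{cases}
```

The symbols $B_1$, $X_0$, $X_1$, $n_0$, $n_1$, and $d$ are notation introduced here for illustration.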
- forward(batch)
Forward pass for the CCCN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- class CCXNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the CCXN model.
This wrapper defines the forward pass of the model. The CCXN backbone returns embeddings for the cells of rank 0, 1, and 2 (nodes, edges, and 2-cells).
- forward(batch)
Forward pass for the CCXN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
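For backbones that return embeddings for ranks 0, 1, and 2, the wrapper's job reduces to calling the backbone and packing its outputs into the returned dictionary. The sketch below illustrates that pattern only. `MultiRankWrapperSketch`, `identity_backbone`, the backbone call signature, and the dictionary keys (`labels`, `x_0`, `x_1`, `x_2`) are all hypothetical and are not TopoBench's actual API. A `SimpleNamespace` stands in for a `torch_geometric.data.Data` batch.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple


class MultiRankWrapperSketch:
    """Hypothetical wrapper mirroring the documented (backbone, **kwargs) constructor."""

    def __init__(self, backbone: Callable[..., Tuple[Any, Any, Any]], **kwargs: Any) -> None:
        self.backbone = backbone
        self.options = kwargs  # extra configuration; unused in this sketch

    def forward(self, batch: Any) -> Dict[str, Any]:
        # Assumption: the backbone maps per-rank features to per-rank embeddings.
        x_0, x_1, x_2 = self.backbone(batch.x_0, batch.x_1, batch.x_2)
        # Pack labels and one embedding entry per rank (key names are assumptions).
        return {"labels": batch.y, "x_0": x_0, "x_1": x_1, "x_2": x_2}


def identity_backbone(x_0: Any, x_1: Any, x_2: Any) -> Tuple[Any, Any, Any]:
    # Toy backbone: returns its inputs unchanged, one output per rank.
    return x_0, x_1, x_2


batch = SimpleNamespace(x_0=[[1.0]], x_1=[[2.0]], x_2=[[3.0]], y=[0])
out = MultiRankWrapperSketch(identity_backbone).forward(batch)
print(sorted(out))  # ['labels', 'x_0', 'x_1', 'x_2']
```

Keeping the wrapper separate from the backbone lets the same backbone be reused while the wrapper decides which inputs to pass and how to shape the output dictionary.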
- class CWNWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the CWN model.
This wrapper defines the forward pass of the model. The CWN backbone returns embeddings for the cells of rank 0, 1, and 2 (nodes, edges, and 2-cells).
- forward(batch)
Forward pass for the CWN wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.