Source code for topobench.nn.wrappers.cell.cwn_wrapper

"""Wrapper for the CWN model."""

from topobench.nn.wrappers.base import AbstractWrapper


class CWNWrapper(AbstractWrapper):
    r"""Wrapper that adapts batched cell-complex data to the CWN backbone.

    The wrapper pulls the cell features and neighborhood matrices out of the
    batch, hands them to ``self.backbone`` by keyword, and packs the three
    returned embeddings (cells of rank 0, 1 and 2) into an output dictionary.
    """

    def forward(self, batch):
        r"""Run the CWN backbone on a batch and collect its outputs.

        Parameters
        ----------
        batch : torch_geometric.data.Data
            Batched domain data. The attributes read are ``x_0``, ``x_1``,
            ``x_2``, ``incidence_1``, ``adjacency_1``, ``incidence_2``,
            ``y`` and ``batch_0``.

        Returns
        -------
        dict
            Mapping with keys ``"labels"`` (``batch.y``), ``"batch_0"``
            (``batch.batch_0``) and ``"x_0"``, ``"x_1"``, ``"x_2"`` (the three
            values returned by the backbone, in that order).
        """
        # NOTE(review): the backbone's ``adjacency_0`` argument is fed the
        # batch's ``adjacency_1`` matrix. This presumably matches the
        # backbone's own naming convention rather than TopoBench's rank
        # naming; confirm against the backbone's forward signature before
        # changing it.
        embeddings = self.backbone(
            x_0=batch.x_0,
            x_1=batch.x_1,
            x_2=batch.x_2,
            incidence_1_t=batch.incidence_1.T,
            adjacency_0=batch.adjacency_1,
            incidence_2=batch.incidence_2,
        )
        # The backbone is expected to return exactly three embeddings.
        rank0_out, rank1_out, rank2_out = embeddings

        return {
            "labels": batch.y,
            "batch_0": batch.batch_0,
            "x_0": rank0_out,
            "x_1": rank1_out,
            "x_2": rank2_out,
        }