Source code for topobench.transforms.liftings.simplicial2combinatorial.coface_cc_lifting

"""The CofaceCCLifting lifting."""

from toponetx.classes.combinatorial_complex import CombinatorialComplex
from toponetx.classes.hyperedge import HyperEdge
from torch_geometric.data import Data

from topobench.data.utils.utils import (
    get_combinatorial_complex_connectivity,
)
from topobench.transforms.liftings.simplicial2combinatorial.base import (
    Simplicial2CombinatorialLifting,
)


[docs] class CofaceCCLifting(Simplicial2CombinatorialLifting): """The CofaceCCLifting class. This class lifts a simplicial complex to a combinatorial complex by using the coface relation between the simplicial cells. Parameters ---------- **kwargs : dict The keyword arguments. """ def __init__(self, **kwargs): super().__init__(**kwargs) self.keep_features = kwargs.get("keep_features", False)
[docs] def get_lower_cells(self, data: Data) -> list[HyperEdge]: """Get the lower cells of the complex. Parameters ---------- data : torch_geometric.data.Data The input data. Returns ------- list The list of lower cells. """ cells: list[HyperEdge] = [] ## Add 0-cells for cell in range(data["x_0"].size(0)): zero_cell = HyperEdge([cell], rank=0) cells.append(zero_cell) ## Add 1-cells for inc_c_1 in data["incidence_1"].to_dense().T: # Get the 0-cells that are incident to the 1-cell cell_0_bound = inc_c_1.nonzero().flatten().tolist() assert len(cell_0_bound) == 2 one_cell = HyperEdge(cell_0_bound, rank=1) cells.append(one_cell) ## Add 2-cells for inc_c_2 in data["incidence_2"].to_dense().T: # Get the 1-cells that are incident to the 2-cell cell_1_bound = inc_c_2.nonzero().flatten() # Get the 0-cells that are incident to the 1-cells cell_0_bound = ( data["incidence_1"].to_dense().T[cell_1_bound].nonzero() ) # Get the actual 0-cells since nonzero() # indexes in 2D cell_0_bound = cell_0_bound[:, 1] # Remove redudants and convert to tuple two_cell = HyperEdge(tuple(set(cell_0_bound.tolist())), rank=2) cells.append(two_cell) return cells
[docs] def lift_topology(self, data: Data) -> dict: """Lift the simplicial topology to a combinatorial complex. Parameters ---------- data : torch_geometric.data.Data The input data. Returns ------- dict The lifted connectivity dict. """ # Check that the dataset has the required fields # assume that it's a simplicial dataset assert "incidence_1" in data assert "incidence_2" in data cells = self.get_lower_cells(data) ccc = CombinatorialComplex(cells, graph_based=False) # Iterate over the 2-cells and add the 3-cells for r_cell in ccc.skeleton(rank=2): # Get the coface of the 2-cell indices, coface = ccc.coadjacency_matrix(2, 1, index=True) # Get the indices of the 2-cell that are co-adjacent coface_indices = ( coface.todense()[indices[r_cell]].nonzero()[1].tolist() ) cell_3 = set(r_cell) # Iterate over the indices of the 2-cells # and add their 0-cells as a 3-cell for idx in coface_indices: cell_3 = cell_3.union(set(ccc.skeleton(rank=2)[idx])) # Adding a rank 3 cell with less than 4 vertices # will take this cell from the skeleton of 2-cells if it exists # so in the interest of keeping features the user # can choose to recompute all feature embeddings if len(cell_3) < 4 and self.keep_features: continue # Get the cofaces incident to the 2-cell `cell` and add `cell` to the set ccc.add_cell(cell_3, rank=3) # Create the incidence, adjacency and laplacian matrices lifted_data = get_combinatorial_complex_connectivity(ccc, 3) # If the user wants to keep the features # from the r-cells aside from the first x_0 if self.keep_features: lifted_data = { "x_0": data["x_0"], "x_1": data["x_1"], "x_2": data["x_2"], **lifted_data, } else: lifted_data = {"x_0": data["x_0"], **lifted_data} return lifted_data
[docs] def forward(self, data: Data) -> Data: """Forward pass. Parameters ---------- data : torch_geometric.data.Data The input data. Returns ------- torch_geometric.data.Data The updated lifted data. """ initial_data = data.to_dict() lifted_topology = self.lift_topology(data) lifted_topology = self.feature_lifting(lifted_topology) # Make sure to remove passing of duplicated data # so that the constructor of Data does not raise an error for k in lifted_topology: if k in initial_data: del initial_data[k] return Data(**initial_data, **lifted_topology)