Source code for topobenchmark.transforms.feature_liftings.concatenation

"""Concatenation feature lifting."""

import torch
import torch_geometric


class Concatenation(torch_geometric.transforms.BaseTransform):
    r"""Lift r-cell features to r+1-cells by concatenation.

    For every ``incidence_<r>`` entry whose matching ``x_<r>`` features are
    missing, the feature rows of the lower cells referenced by each column of
    the incidence matrix are concatenated into one feature vector per column.

    Parameters
    ----------
    **kwargs : optional
        Additional arguments for the class. They are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def lift_features(
        self, data: torch_geometric.data.Data | dict
    ) -> torch_geometric.data.Data | dict:
        r"""Concatenate r-cell features to obtain r+1-cell features.

        Parameters
        ----------
        data : torch_geometric.data.Data | dict
            The input data to be lifted. It is modified in place: a new
            ``x_<r>`` entry is written for every ``incidence_<r>`` entry that
            does not already have one.

        Returns
        -------
        torch_geometric.data.Data | dict
            The lifted data (the same object that was passed in).
        """
        # Collect the suffix of every "incidence_<suffix>" key up front, so
        # that writing new "x_<suffix>" entries below cannot disturb the
        # iteration. Keys containing "-" are skipped. Sorting processes lower
        # ranks first, so x_<r> is available when x_<r+1> is built from it.
        keys = sorted(
            key.split("_")[1]
            for key in data
            if "incidence" in key and "-" not in key
        )

        for elem in keys:
            if f"x_{elem}" in data:
                # Features for these cells already exist; keep them.
                continue

            # Hyperedge features are built from node features (x_0); rank-r
            # features are built from rank-(r-1) features.
            is_hyperedge = elem == "hyperedges"
            idx_to_project = 0 if is_hyperedge else int(elem) - 1
            incidence = data["incidence_" + elem]
            _, n = incidence.shape
            source = data[f"x_{idx_to_project}"]

            if n != 0:
                # Loop-invariant: fetch the sparse indices once instead of
                # twice per column.
                indices = incidence.indices()
                rows, cols = indices[0], indices[1]
                # One sorted vector of row indices per column. torch.stack
                # requires every column to reference the same number of rows.
                idxs = torch.stack(
                    [torch.sort(rows[cols == col])[0] for col in range(n)],
                    dim=0,
                )
                # (n, k, d) -> (n, k * d): concatenate the k gathered rows.
                values = source[idxs].view(n, -1)
            else:
                # No cells of this kind: emit an empty feature matrix with the
                # width a populated one would have (rank + 1 lower cells per
                # cell). With no hyperedges their size is unknown, so the
                # width falls back to a single copy of the node features
                # rather than failing on int("hyperedges").
                copies = 1 if is_hyperedge else int(elem) + 1
                # new_zeros keeps the dtype and device of the source
                # features, matching what the non-empty branch produces.
                values = source.new_zeros((0, source.shape[1] * copies))

            data["x_" + elem] = values
        return data

    def forward(
        self, data: torch_geometric.data.Data | dict
    ) -> torch_geometric.data.Data | dict:
        r"""Apply the lifting to the input data.

        Parameters
        ----------
        data : torch_geometric.data.Data | dict
            The input data to be lifted.

        Returns
        -------
        torch_geometric.data.Data | dict
            The lifted data.
        """
        return self.lift_features(data)