Source code for topobench.transforms.data_manipulations.redefine_simplicial_neighbourhoods
"""A transform that redefines simplicial complex neighbourhoods."""
import torch_geometric
from topobench.data.utils import data2simplicial
from topobench.data.utils.utils import get_complex_connectivity
[docs]
class RedefineSimplicialNeighbourhoods(
    torch_geometric.transforms.BaseTransform
):
    r"""A transform that redefines the neighbourhoods of a simplicial complex.

    Parameters
    ----------
    **kwargs : optional
        Configuration stored verbatim in ``self.parameters``; it is not
        forwarded to the base transform. ``forward`` reads the keys
        ``"complex_dim"``, ``"neighborhoods"`` and ``"signed"`` from it and
        raises ``KeyError`` if any of them is missing.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.type = "RedefineSimplicialNeighbourhoods"
        self.parameters = kwargs

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})"

    def forward(self, data: torch_geometric.data.Data):
        r"""Apply the transform to the input data.

        The simplicial complex is rebuilt from ``data``, its connectivity is
        recomputed with the configured parameters, every attribute of
        ``data`` that is not a feature/label key is dropped, and the newly
        computed entries are written back onto ``data``.

        Parameters
        ----------
        data : torch_geometric.data.Data
            The input data. It is modified in place.

        Returns
        -------
        torch_geometric.data.Data
            The same data object, with its old topology attributes replaced.
        """
        # Feature and label attributes that must survive the rewrite.
        keys_to_keep = {"x", "x_0", "x_1", "x_2", "y"}

        # Build the new topology before touching ``data`` so that a failure
        # here leaves the input untouched.
        simplicial_complex = data2simplicial(data)
        lifted_topology = get_complex_connectivity(
            simplicial_complex,
            self.parameters["complex_dim"],
            neighborhoods=self.parameters["neighborhoods"],
            signed=self.parameters["signed"],
        )

        # Get rid of the old keys. Collect them first: removing entries while
        # iterating over ``data`` mutates the mapping being iterated, which
        # can raise ``RuntimeError`` or skip keys.
        stale_keys = [key for key, _ in data if key not in keys_to_keep]
        for key in stale_keys:
            data.pop(key)

        # Assign new topology
        for key in lifted_topology:
            data[key] = lifted_topology[key]
        return data