# Source code for encoders.dgm_encoder
"""Encoder class to apply BaseEncoder."""
import torch_geometric
from topobench.nn.encoders.all_cell_encoder import BaseEncoder
from topobench.nn.encoders.base import AbstractFeatureEncoder
from .kdgm import DGM_d
# [docs]
class DGMStructureFeatureEncoder(AbstractFeatureEncoder):
    r"""Encoder class to apply BaseEncoder.

    The BaseEncoder is applied to the features of higher order structures.
    For each dimension listed in ``selected_dimensions`` (or every index of
    ``in_channels`` when none are given), two BaseEncoder projections are
    built and wrapped in a ``DGM_d`` module, stored as attribute
    ``encoder_{i}``. During the forward pass each wrapped encoder is applied
    to the features of the corresponding dimension.

    Parameters
    ----------
    in_channels : list[int]
        Input dimensions for the features, one entry per dimension.
    out_channels : list[int]
        Output dimensions for the features.
    proj_dropout : float, optional
        Dropout for the BaseEncoders (default: 0).
    selected_dimensions : list[int], optional
        List of indexes to apply the BaseEncoders to (default: None).
    **kwargs : dict, optional
        Additional keyword arguments.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        proj_dropout=0,
        selected_dimensions=None,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Default to every dimension of in_channels when no explicit
        # selection is provided (a range object, matching the list protocol
        # closely enough for iteration and repr).
        if selected_dimensions is not None:
            self.dimensions = selected_dimensions
        else:
            self.dimensions = range(len(self.in_channels))
        for dim in self.dimensions:
            # Two independent projections per dimension: one used as the
            # base encoder, one as the embedding function inside DGM_d.
            # NOTE(review): the whole out_channels list (not out_channels[dim])
            # is handed to BaseEncoder — mirrors the original; confirm against
            # BaseEncoder's expected signature.
            dgm = DGM_d(
                base_enc=BaseEncoder(
                    self.in_channels[dim],
                    self.out_channels,
                    dropout=proj_dropout,
                ),
                embed_f=BaseEncoder(
                    self.in_channels[dim],
                    self.out_channels,
                    dropout=proj_dropout,
                ),
            )
            setattr(self, f"encoder_{dim}", dgm)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return (
            f"{cls_name}(in_channels={self.in_channels}, "
            f"out_channels={self.out_channels}, "
            f"dimensions={self.dimensions})"
        )

    def forward(
        self, data: torch_geometric.data.Data
    ) -> torch_geometric.data.Data:
        r"""Forward pass.

        Applies each per-dimension DGM encoder to the matching ``x_{i}``
        features of ``data``, storing the updated features plus auxiliary
        outputs back on the data object.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Input data object which should contain x_{i} features for each i
            in the selected_dimensions.

        Returns
        -------
        torch_geometric.data.Data
            Output data object with updated x_{i} features.
        """
        # Ensure node features exist under the dimension-0 key.
        if not hasattr(data, "x_0"):
            data.x_0 = data.x
        for dim in self.dimensions:
            encoder_key = f"encoder_{dim}"
            # Skip dimensions without features or without a built encoder.
            if not (hasattr(data, f"x_{dim}") and hasattr(self, encoder_key)):
                continue
            # NOTE(review): batch_{dim} is fetched without a default, so it
            # must be present whenever x_{dim} is — confirm upstream.
            batch = getattr(data, f"batch_{dim}")
            features, aux_features, dgm_edges, log_probs = getattr(
                self, encoder_key
            )(data[f"x_{dim}"], batch)
            data[f"x_{dim}"] = features
            data[f"x_aux_{dim}"] = aux_features
            # NOTE(review): single shared key, overwritten on each iteration —
            # only the last processed dimension's edges survive; mirrors the
            # original behavior.
            data["edges_index"] = dgm_edges
            data[f"logprobs_{dim}"] = log_probs
        return data