Source code for topobench.nn.encoders.all_cell_encoder
"""Class to apply BaseEncoder to the features of higher order structures."""
import torch
import torch_geometric
from torch_geometric.nn.norm import GraphNorm
from topobench.nn.encoders.base import AbstractFeatureEncoder
[docs]
class AllCellFeatureEncoder(AbstractFeatureEncoder):
r"""Encoder class to apply BaseEncoder.
The BaseEncoder is applied to the features of higher order
structures. The class creates a BaseEncoder for each dimension specified in
selected_dimensions. Then during the forward pass, the BaseEncoders are
applied to the features of the corresponding dimensions.
Parameters
----------
in_channels : list[int]
Input dimensions for the features.
out_channels : list[int]
Output dimensions for the features.
proj_dropout : float, optional
Dropout for the BaseEncoders (default: 0).
selected_dimensions : list[int], optional
List of indexes to apply the BaseEncoders to (default: None).
**kwargs : dict, optional
Additional keyword arguments.
"""
def __init__(
self,
in_channels,
out_channels,
proj_dropout=0,
selected_dimensions=None,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.dimensions = (
selected_dimensions
if (
selected_dimensions is not None
) # and len(selected_dimensions) <= len(self.in_channels))
else range(len(self.in_channels))
)
for i in self.dimensions:
setattr(
self,
f"encoder_{i}",
BaseEncoder(
self.in_channels[i],
self.out_channels,
dropout=proj_dropout,
),
)
def __repr__(self):
return f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, dimensions={self.dimensions})"
[docs]
def forward(
self, data: torch_geometric.data.Data
) -> torch_geometric.data.Data:
r"""Forward pass.
The method applies the BaseEncoders to the features of the selected_dimensions.
Parameters
----------
data : torch_geometric.data.Data
Input data object which should contain x_{i} features for each i in the selected_dimensions.
Returns
-------
torch_geometric.data.Data
Output data object with updated x_{i} features.
"""
if not hasattr(data, "x_0"):
data.x_0 = data.x
for i in self.dimensions:
if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"):
batch = getattr(data, f"batch_{i}")
data[f"x_{i}"] = getattr(self, f"encoder_{i}")(
data[f"x_{i}"], batch
)
return data
[docs]
class BaseEncoder(torch.nn.Module):
r"""Base encoder class used by AllCellFeatureEncoder.
This class uses two linear layers with GraphNorm, Relu activation function, and dropout between the two layers.
Parameters
----------
in_channels : int
Dimension of input features.
out_channels : int
Dimensions of output features.
dropout : float, optional
Percentage of channels to discard between the two linear layers (default: 0).
"""
def __init__(self, in_channels, out_channels, dropout=0):
super().__init__()
self.linear1 = torch.nn.Linear(in_channels, out_channels)
self.linear2 = torch.nn.Linear(out_channels, out_channels)
self.relu = torch.nn.ReLU()
self.BN = GraphNorm(out_channels)
self.dropout = torch.nn.Dropout(dropout)
def __repr__(self):
return f"{self.__class__.__name__}(in_channels={self.linear1.in_features}, out_channels={self.linear1.out_features})"
[docs]
def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
r"""Forward pass of the encoder.
It applies two linear layers with GraphNorm, Relu activation function, and dropout between the two layers.
Parameters
----------
x : torch.Tensor
Input tensor of dimensions [N, in_channels].
batch : torch.Tensor
The batch vector which assigns each element to a specific example.
Returns
-------
torch.Tensor
Output tensor of shape [N, out_channels].
"""
x = self.linear1(x)
x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x)
x = self.dropout(self.relu(x))
x = self.linear2(x)
return x