topobench.nn.encoders.dgm_encoder module

Encoder class to apply BaseEncoder.

class AbstractFeatureEncoder

Bases: Module

Abstract class to define a custom feature encoder.

__init__()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

abstract forward(data)

Forward pass of the feature encoder model.

Parameters:
data : torch_geometric.data.Data

Input data object which should contain x features.
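As a sketch of how a concrete encoder might implement this interface, the example below assumes the abstract base is built with Python's abc module and that forward returns the updated data object; both are assumptions, not confirmed by this page. The class names AbstractFeatureEncoderSketch and LinearFeatureEncoder are hypothetical, and a types.SimpleNamespace stands in for torch_geometric.data.Data so the block runs with PyTorch alone.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace

import torch
from torch import nn


class AbstractFeatureEncoderSketch(nn.Module, ABC):
    """Hypothetical stand-in for AbstractFeatureEncoder (abc usage is assumed)."""

    @abstractmethod
    def forward(self, data):
        """Encode the features stored on ``data`` and return the updated object."""


class LinearFeatureEncoder(AbstractFeatureEncoderSketch):
    """Hypothetical subclass: projects ``data.x`` with a single linear layer."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, data):
        data.x = self.lin(data.x)  # [N, in_channels] -> [N, out_channels]
        return data


# SimpleNamespace stands in for torch_geometric.data.Data.
encoder = LinearFeatureEncoder(3, 8)
data = encoder(SimpleNamespace(x=torch.randn(5, 3)))
print(tuple(data.x.shape))  # (5, 8)
```

Because forward is abstract in this sketch, instantiating AbstractFeatureEncoderSketch directly raises TypeError, so every concrete encoder must supply its own forward pass.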

class BaseEncoder(in_channels, out_channels, dropout=0)

Bases: Module

Base encoder class used by AllCellFeatureEncoder.

This class uses two linear layers, with GraphNorm, a ReLU activation function, and dropout between them.

Parameters:
in_channels : int

Dimension of input features.

out_channels : int

Dimension of output features.

dropout : float, optional

Dropout probability applied between the two linear layers (default: 0).

__init__(in_channels, out_channels, dropout=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, batch)

Forward pass of the encoder.

It applies two linear layers, with GraphNorm, a ReLU activation function, and dropout between them.

Parameters:
x : torch.Tensor

Input tensor of shape [N, in_channels].

batch : torch.Tensor

The batch vector which assigns each element to a specific example.

Returns:
torch.Tensor

Output tensor of shape [N, out_channels].
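The following is a minimal sketch of the layer stack described above, not the library's implementation. It assumes the order Linear, then GraphNorm, ReLU and dropout, then Linear, and it replaces PyTorch Geometric's GraphNorm with PerGraphNorm, a hypothetical re-implementation of the same per-graph normalization formula (learnable weight, bias, and mean scale) so the block depends only on PyTorch. BaseEncoderSketch is likewise a hypothetical name.

```python
import torch
from torch import nn


class PerGraphNorm(nn.Module):
    """Simplified re-implementation of the GraphNorm formula (not PyG's class)."""

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.mean_scale = nn.Parameter(torch.ones(channels))  # learnable shift factor
        self.eps = eps

    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        num_graphs = int(batch.max()) + 1
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        counts = counts.unsqueeze(-1).to(x.dtype)  # [G, 1] nodes per graph
        # Per-graph feature mean, subtracted with a learnable scale.
        mean = x.new_zeros(num_graphs, x.size(-1)).index_add(0, batch, x) / counts
        centered = x - self.mean_scale * mean[batch]
        # Per-graph variance of the shifted features.
        var = x.new_zeros(num_graphs, x.size(-1)).index_add(0, batch, centered * centered) / counts
        return self.weight * centered / torch.sqrt(var[batch] + self.eps) + self.bias


class BaseEncoderSketch(nn.Module):
    """Hypothetical sketch: Linear, then norm, ReLU and dropout, then Linear."""

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.norm = PerGraphNorm(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(out_channels, out_channels)

    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)             # [N, out_channels]
        x = self.norm(x, batch)         # normalize within each graph
        x = self.dropout(self.relu(x))
        return self.linear2(x)          # [N, out_channels]


x = torch.randn(5, 3)
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
encoder = BaseEncoderSketch(3, 16, dropout=0.1)
print(encoder(x, batch).shape)  # torch.Size([5, 16])
```

Normalizing per graph (rather than over the whole batch, as BatchNorm would) keeps each graph's statistics independent of the other graphs it happens to be batched with.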

class DGMStructureFeatureEncoder(in_channels, out_channels, proj_dropout=0, selected_dimensions=None, **kwargs)

Bases: AbstractFeatureEncoder

Encoder class to apply BaseEncoder.

The BaseEncoder is applied to the features of higher-order structures. The class creates one BaseEncoder for each dimension in selected_dimensions; during the forward pass, each BaseEncoder is applied to the features of its corresponding dimension.

Parameters:
in_channels : list[int]

Input dimensions for the features.

out_channels : list[int]

Output dimensions for the features.

proj_dropout : float, optional

Dropout probability for the BaseEncoders (default: 0).

selected_dimensions : list[int], optional

List of dimension indices to which the BaseEncoders are applied (default: None).

**kwargs : dict, optional

Additional keyword arguments.

__init__(in_channels, out_channels, proj_dropout=0, selected_dimensions=None, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(data)

Forward pass.

The method applies the BaseEncoders to the features of the selected_dimensions.

Parameters:
data : torch_geometric.data.Data

Input data object, which should contain x_{i} features for each i in selected_dimensions.

Returns:
torch_geometric.data.Data

Output data object with updated x_{i} features.
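A minimal sketch of the per-dimension dispatch, assuming the features of dimension i live in the attribute x_{i} as stated above. To stay self-contained it uses nn.Linear as a stand-in for BaseEncoder (so it takes no batch vector and omits proj_dropout), and it assumes that selected_dimensions=None means every dimension listed in in_channels. StructureFeatureEncoderSketch is a hypothetical name.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StructureFeatureEncoderSketch(nn.Module):
    """Hypothetical sketch: one encoder per selected dimension, applied to x_{i}."""

    def __init__(self, in_channels, out_channels, selected_dimensions=None):
        super().__init__()
        # Assumption: None means "encode every dimension listed in in_channels".
        if selected_dimensions is None:
            selected_dimensions = range(len(in_channels))
        self.dimensions = list(selected_dimensions)
        # nn.Linear stands in for BaseEncoder; ModuleDict keys must be strings.
        self.encoders = nn.ModuleDict(
            {str(i): nn.Linear(in_channels[i], out_channels[i]) for i in self.dimensions}
        )

    def forward(self, data):
        for i in self.dimensions:
            x_i = getattr(data, f"x_{i}")
            setattr(data, f"x_{i}", self.encoders[str(i)](x_i))
        return data


# SimpleNamespace stands in for torch_geometric.data.Data.
data = SimpleNamespace(x_0=torch.randn(4, 3), x_1=torch.randn(6, 5))
encoder = StructureFeatureEncoderSketch([3, 5], [8, 8], selected_dimensions=[1])
data = encoder(data)
print(tuple(data.x_0.shape), tuple(data.x_1.shape))  # (4, 3) (6, 8)
```

Dimensions not listed in selected_dimensions (x_0 here) pass through unchanged.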

class DGM_d(base_enc, embed_f, k=5, distance='euclidean', sparse=True)

Bases: Module

Differentiable Graph Module (DGM).

This class learns a graph over the input nodes by sampling edges according to distances between node embeddings, measured in either Euclidean or hyperbolic space.

Parameters:
base_enc : nn.Module

Base encoder for transforming input features.

embed_f : nn.Module

Embedding function for further feature transformation.

k : int, optional

Number of edges to sample per node. Defaults to 5.

distance : str, optional

Distance metric used for edge sampling: 'euclidean' or 'hyperbolic'. Defaults to 'euclidean'.

sparse : bool, optional

Flag to indicate sparse sampling strategy. Defaults to True.
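The two distance choices can be sketched as pairwise-distance functions over the embedded features. This assumes that 'hyperbolic' means the Poincaré ball model with curvature -1, where d(u, v) = arcosh(1 + 2‖u − v‖² / ((1 − ‖u‖²)(1 − ‖v‖²))); the library may instead use a variant such as squared Euclidean distances. The function names are hypothetical.

```python
import torch


def pairwise_euclidean(z: torch.Tensor) -> torch.Tensor:
    """(n, n) matrix of Euclidean distances between the rows of z."""
    return torch.cdist(z, z)


def pairwise_poincare(z: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """(n, n) Poincare-ball distances; rows of z must lie inside the unit ball."""
    sq_norm = (z * z).sum(-1).clamp(max=1 - eps)  # keep ||z_i||^2 strictly below 1
    sq_dist = torch.cdist(z, z) ** 2               # ||z_i - z_j||^2
    denom = (1 - sq_norm).unsqueeze(1) * (1 - sq_norm).unsqueeze(0)
    # Clamp guards acosh against arguments that round to slightly below 1.
    return torch.acosh((1 + 2 * sq_dist / denom).clamp(min=1.0))


def pairwise_distances(z: torch.Tensor, distance: str = "euclidean") -> torch.Tensor:
    """Dispatch on the same string values as the ``distance`` parameter."""
    if distance == "euclidean":
        return pairwise_euclidean(z)
    if distance == "hyperbolic":
        return pairwise_poincare(z)
    raise ValueError(f"unknown distance: {distance!r}")


z = torch.tensor([[0.0, 0.0], [0.5, 0.0]])
print(pairwise_distances(z, "hyperbolic")[0, 1].item())  # about 1.0986, i.e. ln 3
```

In hyperbolic space, distances grow rapidly near the boundary of the ball, which gives the embedding room to separate hierarchical structure that is cramped in Euclidean space.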

__init__(base_enc, embed_f, k=5, distance='euclidean', sparse=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, batch, fixedges=None)

Forward pass of the DGM module.

Parameters:
x : torch.Tensor

Input tensor containing node features.

batch : torch.Tensor

Batch vector assigning each node to its graph.

fixedges : torch.Tensor, optional

Predefined edges to use instead of sampling. Defaults to None.

Returns:
tuple

A tuple containing four elements:

- base_encoded_features (torch.Tensor)
- embedded_features (torch.Tensor)
- sampled_edges (torch.Tensor)
- edge_sampling_log_probabilities (torch.Tensor)

sample_without_replacement(logits)

Sample edges without replacement using a temperature-scaled Gumbel-top-k method.

Parameters:
logits : torch.Tensor

Input logits representing edge weights or distances. Shape should be (n, n), where n is the number of nodes.

Returns:
tuple

A tuple containing two elements:

- edges (torch.Tensor): sampled edges, drawn without replacement
- logprobs (torch.Tensor): log-probabilities of the sampled edges
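A sketch of temperature-scaled Gumbel-top-k sampling on a dense (n, n) logit matrix follows. Assumptions: higher logits mean more likely edges (so distances should be negated before the call), the temperature is a fixed float rather than a learned parameter, edges are returned as a (2, n·k) [source, target] index, and the returned log-probabilities are those of each sampled neighbour under the unperturbed row-wise softmax, a common simplification. gumbel_top_k_edges is a hypothetical name, not this module's API.

```python
from typing import Optional

import torch


def gumbel_top_k_edges(logits: torch.Tensor, k: int, temperature: float = 1.0,
                       generator: Optional[torch.Generator] = None):
    """Sample k distinct neighbours per node with the Gumbel-top-k trick."""
    n = logits.size(0)
    scaled = logits / temperature
    # Uniform noise in [0, 1); clamp away from 0 so the double log stays finite.
    u = torch.rand(scaled.shape, generator=generator, dtype=scaled.dtype, device=scaled.device)
    u = u.clamp_min(torch.finfo(scaled.dtype).tiny)
    gumbel = -torch.log(-torch.log(u))
    # Top-k of the perturbed logits in each row = k samples without replacement.
    _, cols = torch.topk(scaled + gumbel, k, dim=-1)                  # (n, k)
    rows = torch.arange(n, device=logits.device).unsqueeze(-1).expand(-1, k)
    edges = torch.stack([rows.reshape(-1), cols.reshape(-1)])          # (2, n*k)
    # Log-probability of each sampled neighbour under the row-wise softmax.
    logprobs = torch.log_softmax(scaled, dim=-1).gather(-1, cols)      # (n, k)
    return edges, logprobs


# Usage: negate distances to obtain logits, then sample 2 neighbours per node.
z = torch.randn(6, 3)
edges, logprobs = gumbel_top_k_edges(-torch.cdist(z, z), k=2)
print(edges.shape, logprobs.shape)  # torch.Size([2, 12]) torch.Size([6, 2])
```

With raw distances, each node's zero distance to itself yields the largest logit in its row, so self-loops are likely unless the diagonal is masked before sampling. Lowering the temperature sharpens the distribution toward the plain k-nearest neighbours.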