topobench.nn.encoders.kdgm module

KDGM module.

class DGM_d(base_enc, embed_f, k=5, distance='euclidean', sparse=True)

Bases: Module

Distance Graph Matching (DGM) neural network module.

This class implements a graph matching technique that learns to sample edges from pairwise distances between node embeddings, measured in either Euclidean or hyperbolic space.
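To build intuition for "sampling edges based on distance", the sketch below shows the deterministic baseline: connect each node to its k nearest neighbors. It is a minimal, framework-free illustration that assumes plain Python lists in place of torch.Tensor. The helper name `knn_edges` is hypothetical, not part of TopoBench. DGM_d replaces this hard top-k selection with stochastic sampling.

```python
def knn_edges(points, k):
    """Link each node to its k nearest neighbors by squared Euclidean distance.

    Hypothetical helper for illustration only: self-pairs are skipped and
    ties are broken by node index.
    """
    n = len(points)
    edges = []
    for i in range(n):
        # (squared distance, neighbor index) for every other node, closest first.
        dists = sorted(
            (sum((a - b) ** 2 for a, b in zip(points[i], points[j])), j)
            for j in range(n)
            if j != i
        )
        # Keep the k closest neighbors as directed (source, target) pairs.
        edges.extend((i, j) for _, j in dists[:k])
    return edges


points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (5.0, 5.0)]
edges = knn_edges(points, k=2)
print(edges)
# [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 2), (3, 1)]
```

Each node contributes exactly k edges, so the graph has n * k directed edges in total.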

Parameters:
base_enc : nn.Module

Base encoder for transforming input features.

embed_f : nn.Module

Embedding function for further feature transformation.

k : int, optional

Number of edges sampled per node, that is, the number of neighbors selected for each node. Defaults to 5.

distance : str, optional

Distance metric to use for edge sampling. Choices are 'euclidean' or 'hyperbolic'. Defaults to 'euclidean'.

sparse : bool, optional

Flag to indicate sparse sampling strategy. Defaults to True.

__init__(base_enc, embed_f, k=5, distance='euclidean', sparse=True)

Initialize the module with the encoders and sampling options described in the class parameters.

forward(x, batch, fixedges=None)

Forward pass of the Distance Graph Matching module.

Parameters:
x : torch.Tensor

Input tensor containing node features.

batch : torch.Tensor

Batch information for graph-level processing.

fixedges : torch.Tensor, optional

Predefined edges to use instead of sampling. Defaults to None.

Returns:
tuple

A tuple containing four elements:

- base_encoded_features (torch.Tensor)
- embedded_features (torch.Tensor)
- sampled_edges (torch.Tensor)
- edge_sampling_log_probabilities (torch.Tensor)

sample_without_replacement(logits)

Sample edges without replacement using a temperature-scaled Gumbel-top-k method.

Parameters:
logits : torch.Tensor

Input logits representing edge weights or distances. Shape should be (n, n) where n is the number of nodes.

Returns:
tuple

A tuple containing two elements:

- edges (torch.Tensor): Sampled edges without replacement
- logprobs (torch.Tensor): Log probabilities of the sampled edges
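The Gumbel-top-k trick draws k distinct items at once. Add independent Gumbel(0, 1) noise to each score and keep the k largest perturbed scores. The sketch below applies it to one row of distances, using score = -distance / temperature so that nearer nodes are more likely to be picked. It is a minimal illustration in plain Python, not the module's code, and rests on several assumptions. The name `gumbel_top_k` is hypothetical. Dividing by the temperature is an assumed scaling convention. The returned log-probabilities are the unperturbed log-softmax values. The module's own temperature handling and `logprobs` definition may differ.

```python
import math
import random


def gumbel_top_k(dist_row, k, temperature=1.0, rng=random):
    """Sample k distinct indices from one row of distances (hypothetical helper).

    Returns the sampled indices and their log-probabilities under the
    unperturbed softmax of the scores (an assumed convention).
    """
    # Smaller distance -> larger score (temperature scaling is an assumption).
    scores = [-d / temperature for d in dist_row]
    # Numerically stable log-softmax of the scores.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    log_p = [s - log_z for s in scores]
    # Perturb each score with Gumbel(0, 1) noise: g = -log(-log(u)), u ~ U(0, 1).
    perturbed = []
    for i, s in enumerate(scores):
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
        perturbed.append((s - math.log(-math.log(u)), i))
    # The k largest perturbed scores form a sample without replacement.
    top = sorted(perturbed, reverse=True)[:k]
    indices = [i for _, i in top]
    return indices, [log_p[i] for i in indices]


rng = random.Random(0)
dist = [[0.0, 1.0, 4.0, 50.0],
        [1.0, 0.0, 5.0, 41.0],
        [4.0, 5.0, 0.0, 34.0],
        [50.0, 41.0, 34.0, 0.0]]
edges = []
for i, row in enumerate(dist):
    idx, _ = gumbel_top_k(row, k=2, rng=rng)
    edges.extend((i, j) for j in idx)  # k sampled neighbors per node
print(edges)
```

Because the diagonal of `dist` is zero, a node frequently samples itself. Whether to mask self-pairs is a separate design choice that this sketch leaves open.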

pairwise_euclidean_distances(x, dim=-1)

Compute pairwise Euclidean distances between points in a tensor.

Parameters:
x : torch.Tensor

Input tensor of points. Each row represents a point in a multidimensional space.

dim : int, optional

Dimension along which to compute the squared distances. Defaults to -1 (last dimension).

Returns:
tuple

A tuple containing two elements:

- dist (torch.Tensor): Squared pairwise Euclidean distances matrix
- x (torch.Tensor): The original input tensor
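A common way to compute all squared distances at once uses the identity ||a - b||² = ||a||² + ||b||² - 2 a·b. This identity lets a tensor library reuse one matrix product. The plain-Python sketch below spells it out and follows the documented `(dist, x)` return convention. It is an illustration under assumptions: lists of rows stand in for a tensor, reduction happens over the last dimension only, and the name `pairwise_sq_euclidean` is hypothetical. The actual method may instead compute distances another way, for example with `torch.cdist`.

```python
def pairwise_sq_euclidean(x):
    """Squared Euclidean distances between all rows of x (hypothetical helper).

    Returns (dist, x) with x unchanged, mirroring the documented convention.
    """
    norms = [sum(v * v for v in row) for row in x]
    dist = []
    for i, a in enumerate(x):
        out_row = []
        for j, b in enumerate(x):
            dot = sum(p * q for p, q in zip(a, b))
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamp rounding noise below zero.
            out_row.append(max(norms[i] + norms[j] - 2.0 * dot, 0.0))
        dist.append(out_row)
    return dist, x


points = [[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]]
dist, same = pairwise_sq_euclidean(points)
print(dist)
# [[0.0, 25.0, 2.0], [25.0, 0.0, 13.0], [2.0, 13.0, 0.0]]
```

The clamp matters in floating point, where the expanded form can return tiny negative values for nearly identical points.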

pairwise_poincare_distances(x, dim=-1)

Compute pairwise distances in the Poincaré disk model (hyperbolic space).

Parameters:
x : torch.Tensor

Input tensor of points. Each row represents a point in a multidimensional space.

dim : int, optional

Dimension along which to compute the squared distances. Defaults to -1 (last dimension).

Returns:
tuple

A tuple containing two elements:

- dist (torch.Tensor): Squared pairwise hyperbolic distances matrix
- x (torch.Tensor): Normalized input tensor in the Poincaré disk
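In the Poincaré ball, the distance between points u and v with norm below 1 is d(u, v) = arcosh(1 + 2||u - v||² / ((1 - ||u||²)(1 - ||v||²))). Points must therefore first be moved strictly inside the unit ball. The sketch below uses plain Python lists and the hypothetical name `pairwise_sq_poincare`. It assumes one particular normalization: divide each point by max(||x||, 1)·(1 + eps). The module's exact normalization may differ. The sketch returns squared distances plus the normalized points, matching the documented return values.

```python
import math


def pairwise_sq_poincare(x, eps=1e-2):
    """Squared pairwise Poincare-ball distances (hypothetical helper).

    The normalization below is an assumption of this sketch.
    Returns (dist, x_normalized).
    """
    xs = []
    for row in x:
        norm = math.sqrt(sum(v * v for v in row))
        # Pull points strictly inside the unit ball (assumed normalization).
        scale = max(norm, 1.0) * (1.0 + eps)
        xs.append([v / scale for v in row])
    sq_norms = [sum(v * v for v in row) for row in xs]
    dist = []
    for i, a in enumerate(xs):
        out_row = []
        for j, b in enumerate(xs):
            sq_diff = sum((p - q) ** 2 for p, q in zip(a, b))
            denom = (1.0 - sq_norms[i]) * (1.0 - sq_norms[j])
            # d(u, v) = arcosh(1 + 2 ||u - v||^2 / ((1 - ||u||^2)(1 - ||v||^2)))
            d = math.acosh(1.0 + 2.0 * sq_diff / denom)
            out_row.append(d * d)
        dist.append(out_row)
    return dist, xs


# With eps=0, the distance from the origin to a point at radius r is 2 * artanh(r);
# for r = 0.5 that is ln(3), so the squared distance is ln(3) ** 2.
dist, xs = pairwise_sq_poincare([[0.0, 0.0], [0.5, 0.0]], eps=0.0)
print(dist[0][1], math.log(3) ** 2)
```

The eps margin keeps every normalized point off the boundary, where 1 - ||u||² reaches zero and the distance diverges.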