topobench.nn.backbones.graph.nsd_utils.inductive_discrete_models module

Inductive Neural Sheaf Diffusion models.

This module implements three variants of inductive sheaf diffusion:

- Diagonal: diagonal restriction maps.
- Bundle: orthogonal restriction maps with normalization.
- General: full-matrix restriction maps.
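As a rough guide to how the three variants differ in size, the sketch below counts learned restriction-map parameters per edge, using the shapes stated on this page: d values for diagonal maps, d * (d + 1) / 2 for the bundle variant (its get_param_size), and a full d x d matrix for general maps. The helper name params_per_edge is hypothetical and not part of the module.

```python
def params_per_edge(variant: str, d: int) -> int:
    """Learned restriction-map parameters per edge (hypothetical helper)."""
    if variant == "diag":
        return d                  # one scalar per diagonal entry
    if variant == "bundle":
        return d * (d + 1) // 2   # lower-triangular parameters, mapped to an orthogonal matrix
    if variant == "general":
        return d * d              # every entry of a full d x d matrix
    raise ValueError(f"unknown variant: {variant!r}")


for v in ("diag", "bundle", "general"):
    print(v, params_per_edge(v, 3))  # diag 3, bundle 6, general 9
```

The count grows linearly in d for diagonal maps and quadratically for the other two variants.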

class DiagLaplacianBuilder(size, edge_index, d)

Bases: LaplacianBuilder

Builder for sheaf Laplacian with diagonal restriction maps.

This builder constructs a sheaf Laplacian where the restriction maps are diagonal matrices, parameterized by d values per edge.

Parameters:
size : int

Number of nodes in the graph.

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

d : int

Dimension of the stalk space, i.e. the number of diagonal entries in each restriction map.

__init__(size, edge_index, d)

Initialize the builder. Parameters are as described for the class.

forward(maps)

Build the sheaf Laplacian from diagonal restriction maps.

Parameters:
maps : torch.Tensor

Diagonal restriction map parameters of shape [num_edges, d].

Returns:
L : tuple of torch.Tensor

Sparse Laplacian representation as (indices, values).

saved_tril_maps : torch.Tensor

Saved lower-triangular restriction maps for analysis.
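To make the structure of a diagonal sheaf Laplacian concrete, here is a minimal NumPy sketch. It builds the matrix densely for readability rather than in the sparse (indices, values) form the builder returns, assumes an unnormalized Laplacian, and uses a hypothetical layout in which each undirected edge (u, v) carries one diagonal map per endpoint. The names diag_sheaf_laplacian and to_coo are illustrative only.

```python
import numpy as np


def diag_sheaf_laplacian(num_nodes, edges, maps):
    """Dense, unnormalized sheaf Laplacian for diagonal restriction maps.

    edges: list of undirected (u, v) pairs.
    maps:  [num_edges, 2, d]; maps[e, 0] is the diagonal of the map on u's stalk
           and maps[e, 1] the diagonal of the map on v's stalk (hypothetical layout).
    """
    d = maps.shape[-1]
    L = np.zeros((num_nodes * d, num_nodes * d))
    for e, (u, v) in enumerate(edges):
        fu, fv = maps[e, 0], maps[e, 1]
        su, sv = slice(u * d, (u + 1) * d), slice(v * d, (v + 1) * d)
        # Diagonal blocks accumulate F^T F, which is diag(f * f) for a diagonal map.
        L[su, su] += np.diag(fu * fu)
        L[sv, sv] += np.diag(fv * fv)
        # Off-diagonal blocks are -F_u^T F_v = -diag(fu * fv).
        L[su, sv] -= np.diag(fu * fv)
        L[sv, su] -= np.diag(fu * fv)
    return L


def to_coo(dense):
    """Split a dense matrix into (indices [2, nnz], values [nnz]), like a sparse COO pair."""
    rows, cols = np.nonzero(dense)
    return np.stack([rows, cols]), dense[rows, cols]
```

Because every d x d block is itself diagonal, at most d entries per block are nonzero, which is why a sparse (indices, values) representation is a natural output format.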

class GeneralLaplacianBuilder(size, edge_index, d, augmented=True)

Bases: LaplacianBuilder

Builder for general sheaf Laplacian with full matrix restriction maps.

This builder constructs a sheaf Laplacian where the restriction maps are arbitrary d x d matrices learned from data.

Parameters:
size : int

Number of nodes in the graph.

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

d : int

Dimension of the stalk space.

augmented : bool, optional

Whether to use the augmented representation (currently unused). Default is True.

__init__(size, edge_index, d, augmented=True)

Initialize the builder. Parameters are as described for the class.

forward(maps)

Build the sheaf Laplacian from general restriction maps.

Parameters:
maps : torch.Tensor

General restriction map matrices of shape [num_edges, d, d].

Returns:
L : tuple of torch.Tensor

Sparse Laplacian representation as (indices, values).

saved_tril_maps : torch.Tensor

Saved lower-triangular transport maps for analysis.
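For general d x d maps the Laplacian is easiest to understand through the coboundary operator delta, with (delta x)_e = F_{v,e} x_v - F_{u,e} x_u for an edge e = (u, v), and L = delta^T delta. The dense NumPy sketch below assumes an unnormalized Laplacian and a hypothetical [num_edges, 2, d, d] layout holding one map per endpoint; it illustrates the math rather than reproducing the builder's code.

```python
import numpy as np


def general_sheaf_laplacian(num_nodes, edges, maps):
    """Dense sheaf Laplacian L = delta^T delta for arbitrary d x d restriction maps.

    maps: [num_edges, 2, d, d]; maps[e, 0] acts on u's stalk and maps[e, 1] on
          v's stalk for edge e = (u, v) (hypothetical layout).
    """
    num_edges, _, d, _ = maps.shape
    delta = np.zeros((num_edges * d, num_nodes * d))
    for e, (u, v) in enumerate(edges):
        rows = slice(e * d, (e + 1) * d)
        # (delta x)_e = F_{v,e} x_v - F_{u,e} x_u
        delta[rows, v * d:(v + 1) * d] = maps[e, 1]
        delta[rows, u * d:(u + 1) * d] = -maps[e, 0]
    return delta.T @ delta
```

Writing L as delta^T delta makes it symmetric and positive semidefinite by construction, whatever maps are learned. The (u, v) block comes out as -F_{u,e}^T F_{v,e} and each diagonal block as the sum of F^T F over incident edges.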

class InductiveDiscreteBundleSheafDiffusion(config)

Bases: SheafDiffusion

Inductive sheaf diffusion with orthogonal bundle restriction maps.

This model learns orthogonal d x d restriction maps for each edge, ensuring isometric transport between stalks. It uses a normalized Laplacian and enforces orthogonality through a Cayley-transform or matrix-exponential parameterization.

Parameters:
config : dict

Configuration dictionary containing:

- d (int): Dimension of the stalk space (must be > 1).
- layers (int): Number of diffusion layers.
- hidden_channels (int): Hidden channels per stalk dimension.
- input_dim (int): Input feature dimension.
- output_dim (int): Output feature dimension.
- device (str): Device to run on.
- input_dropout (float): Input layer dropout rate.
- dropout (float): Hidden layer dropout rate.
- sheaf_act (str): Activation for sheaf learning.
- orth (str): Orthogonalization method ('cayley' or 'matrix_exp').

Raises:
AssertionError

If d is not greater than 1, or if the model's hidden dimension (hidden_dim, an internal attribute rather than a config key) is not divisible by d.

__init__(config)

Initialize the model from config. Parameters are as described for the class.

forward(x, edge_index)

Forward pass of bundle sheaf diffusion.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, input_dim].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

Returns:
torch.Tensor

Output node features of shape [num_nodes, output_dim].

get_param_size()

Get the number of parameters per edge needed for the orthogonal maps.

Returns:
int

Number of parameters per edge, d * (d + 1) / 2, as required by the lower-triangular parameterization.
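The sketch below shows one way d * (d + 1) / 2 parameters can be turned into an orthogonal d x d matrix with the Cayley transform. It assumes the scheme used in the original Neural Sheaf Diffusion code (fill a lower triangle including the diagonal, skew-symmetrize, then apply the Cayley map); that this module follows exactly the same scheme is an assumption, and the function name cayley_orthogonal is hypothetical.

```python
import numpy as np


def cayley_orthogonal(params, d):
    """Map d * (d + 1) / 2 parameters to a d x d orthogonal matrix (Cayley transform)."""
    P = np.zeros((d, d))
    P[np.tril_indices(d)] = params   # fill the lower triangle, diagonal included
    A = P - P.T                      # skew-symmetric; the diagonal parameters cancel here
    I = np.eye(d)
    # Q = (I - A/2)^{-1} (I + A/2) is orthogonal whenever A is skew-symmetric,
    # and I - A/2 is always invertible because A has purely imaginary eigenvalues.
    return np.linalg.solve(I - A / 2, I + A / 2)
```

The matrix-exponential option would instead return exp(A) (for example via torch.linalg.matrix_exp in PyTorch). Both maps send skew-symmetric matrices to rotations with determinant +1.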

left_right_linear(x, left, right, actual_num_nodes)

Apply left and right linear transformations to stalk vectors.

Parameters:
x : torch.Tensor

Input tensor of shape [num_nodes * d, hidden_channels].

left : nn.Linear

Left linear transformation (acts on stalk dimension).

right : nn.Linear

Right linear transformation (acts on hidden channels).

actual_num_nodes : int

Number of nodes in the current graph.

Returns:
torch.Tensor

Transformed tensor of shape [num_nodes * d, hidden_channels].
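The reshapes behind a left/right transformation are easy to get wrong, so here is a NumPy sketch of the same idea: it computes (I_n ⊗ W_left) X W_right^T without forming the Kronecker product. It assumes bias-free weights passed as plain matrices rather than nn.Linear modules, and node-major row ordering (row node * d + i holds stalk coordinate i of that node). It is an illustration, not the method's actual code.

```python
import numpy as np


def left_right_linear(x, w_left, w_right, num_nodes):
    """Apply w_left to each node's stalk and w_right to the hidden channels.

    x: [num_nodes * d, hidden_channels], rows ordered node by node.
    w_left: [d, d]; w_right: [hidden_channels, hidden_channels] (bias-free here).
    """
    d = w_left.shape[0]
    # Put the stalk dimension last so w_left acts on it row by row.
    x = x.T.reshape(-1, d) @ w_left.T
    # Restore the [num_nodes * d, hidden_channels] layout.
    x = x.reshape(-1, num_nodes * d).T
    # Mix hidden channels independently for every stalk coordinate.
    return x @ w_right.T
```

Passing num_nodes explicitly, as actual_num_nodes does, lets the same weights apply to graphs of any size, which is what makes the model inductive.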

class InductiveDiscreteDiagSheafDiffusion(config)

Bases: SheafDiffusion

Inductive sheaf diffusion with diagonal restriction maps.

This model learns diagonal d x d restriction maps for each edge, parameterized by d scalar values. Because a diagonal map scales each stalk coordinate independently, this variant is suitable for problems where feature channels can be processed independently.

Parameters:
config : dict

Configuration dictionary containing:

- d (int): Dimension of the stalk space (must be > 0).
- layers (int): Number of diffusion layers.
- hidden_channels (int): Hidden channels per stalk dimension.
- input_dim (int): Input feature dimension.
- output_dim (int): Output feature dimension.
- device (str): Device to run on.
- input_dropout (float): Input layer dropout rate.
- dropout (float): Hidden layer dropout rate.
- sheaf_act (str): Activation for sheaf learning.

__init__(config)

Initialize the model from config. Parameters are as described for the class.

forward(x, edge_index)

Forward pass of diagonal sheaf diffusion.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, input_dim].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

Returns:
torch.Tensor

Output node features of shape [num_nodes, output_dim].
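For intuition about what each diffusion layer computes, the sketch below implements the discrete update described in the Neural Sheaf Diffusion paper, X <- X - sigma(L (I_n ⊗ W1) X W2), densely in NumPy. Taking sigma to be ELU, and omitting dropout, the input/output linear layers and any extra terms the module adds, are simplifications; this is not the module's forward code.

```python
import numpy as np


def elu(t):
    # ELU: t for t > 0, exp(t) - 1 otherwise
    return np.where(t > 0, t, np.expm1(np.minimum(t, 0.0)))


def diffusion_step(x, laplacian, w1, w2, num_nodes):
    """One step X <- X - elu(L (I_n kron W1) X W2); x has shape [num_nodes * d, hidden]."""
    mixed = np.kron(np.eye(num_nodes), w1) @ x @ w2   # per-stalk map, then channel mixing
    return x - elu(laplacian @ mixed)
```

On a trivial sheaf (identity restriction maps, so L is the graph Laplacian ⊗ I_d), features that are identical at every node lie in the kernel of L and are left unchanged by the step.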

class InductiveDiscreteGeneralSheafDiffusion(config)

Bases: SheafDiffusion

Inductive sheaf diffusion with general (unrestricted) restriction maps.

This model learns an arbitrary d x d restriction map for each edge. It is the most expressive of the three variants but also needs the most parameters: a full d x d matrix (d * d values) per edge, versus d values for diagonal maps.

Parameters:
config : dict

Configuration dictionary containing:

- d (int): Dimension of the stalk space (must be > 1).
- layers (int): Number of diffusion layers.
- hidden_channels (int): Hidden channels per stalk dimension.
- input_dim (int): Input feature dimension.
- output_dim (int): Output feature dimension.
- device (str): Device to run on.
- input_dropout (float): Input layer dropout rate.
- dropout (float): Hidden layer dropout rate.
- sheaf_act (str): Activation for sheaf learning.

Raises:
AssertionError

If d is not greater than 1.

__init__(config)

Initialize the model from config. Parameters are as described for the class.

forward(x, edge_index)

Forward pass of general sheaf diffusion.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, input_dim].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

Returns:
torch.Tensor

Output node features of shape [num_nodes, output_dim].

left_right_linear(x, left, right, actual_num_nodes)

Apply left and right linear transformations to stalk vectors.

Parameters:
x : torch.Tensor

Input tensor of shape [num_nodes * d, hidden_channels].

left : nn.Linear

Left linear transformation (acts on stalk dimension).

right : nn.Linear

Right linear transformation (acts on hidden channels).

actual_num_nodes : int

Number of nodes in the current graph.

Returns:
torch.Tensor

Transformed tensor of shape [num_nodes * d, hidden_channels].

class LocalConcatSheafLearner(in_channels, out_shape, sheaf_act='tanh')

Bases: SheafLearner

Sheaf learner that concatenates source and target node features.

This learner computes sheaf parameters by concatenating the features of the two endpoints of each edge and passing the result through a linear layer followed by an activation.

Parameters:
in_channels : int

Number of input channels per node.

out_shape : tuple of int

Shape of the output sheaf parameters: (d,) for a diagonal sheaf or (d, d) for a general sheaf.

sheaf_act : str, optional

Activation function to apply. Options are 'id', 'tanh', or 'elu'. Default is 'tanh'.

Raises:
ValueError

If sheaf_act is not one of the supported activation functions.

__init__(in_channels, out_shape, sheaf_act='tanh')

Initialize the learner. Parameters are as described for the class.

forward(x, edge_index)

Compute sheaf parameters from concatenated node features.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, in_channels].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

Returns:
torch.Tensor

Sheaf parameters of shape [num_edges, *out_shape].
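A NumPy sketch of the concatenate-then-project idea: gather the source and target features of every edge, concatenate them, apply one linear map and the chosen activation, and reshape to [num_edges, *out_shape]. The weight is passed as a plain bias-free matrix, and the function name local_concat_sheaf_params is hypothetical; the real learner is an nn.Module whose exact layer configuration may differ.

```python
import numpy as np

ACTIVATIONS = {
    "id": lambda t: t,
    "tanh": np.tanh,
    "elu": lambda t: np.where(t > 0, t, np.expm1(np.minimum(t, 0.0))),
}


def local_concat_sheaf_params(x, edge_index, weight, out_shape, sheaf_act="tanh"):
    """Sheaf parameters per edge from [x_src || x_dst] (hypothetical, bias-free sketch).

    x: [num_nodes, in_channels]; edge_index: [2, num_edges];
    weight: [prod(out_shape), 2 * in_channels].
    """
    if sheaf_act not in ACTIVATIONS:
        raise ValueError(f"Unsupported sheaf_act: {sheaf_act!r}")
    src, dst = edge_index
    pairs = np.concatenate([x[src], x[dst]], axis=1)   # [num_edges, 2 * in_channels]
    params = ACTIVATIONS[sheaf_act](pairs @ weight.T)   # [num_edges, prod(out_shape)]
    return params.reshape(-1, *out_shape)
```

Because the maps are computed from node features rather than stored per edge, the same learner produces restriction maps for edges it has never seen.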

class NormConnectionLaplacianBuilder(size, edge_index, d, orth_map=None)

Bases: LaplacianBuilder

Builder for normalized bundle sheaf Laplacian with orthogonal restriction maps.

This builder constructs a normalized sheaf Laplacian whose restriction maps are orthogonal matrices, parameterized via the Cayley transform or the matrix exponential. It is used by the bundle sheaf models.

Parameters:
size : int

Number of nodes in the graph.

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

d : int

Dimension of the stalk space.

orth_map : str or None, optional

Method for orthogonalization ('cayley' or 'matrix_exp'). Default is None.

__init__(size, edge_index, d, orth_map=None)

Initialize the builder. Parameters are as described for the class.

forward(map_params)

Build the normalized sheaf Laplacian from orthogonal restriction maps.

Parameters:
map_params : torch.Tensor

Orthogonal map parameters of shape [num_edges, d * (d + 1) / 2].

Returns:
L : tuple of torch.Tensor

Sparse normalized Laplacian representation as (indices, values).

saved_tril_maps : torch.Tensor

Saved lower-triangular transport maps for analysis.
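A dense NumPy sketch of the normalization D^{-1/2} L D^{-1/2}, where D is the block diagonal of L. It assumes orthogonal maps (obtained in the usage check from QR decompositions, standing in for the Cayley or matrix-exponential parameterization), the same hypothetical [num_edges, 2, d, d] layout as a per-endpoint map store, and no isolated nodes. It illustrates the idea rather than the builder's exact code.

```python
import numpy as np


def norm_connection_laplacian(num_nodes, edges, maps):
    """Dense D^{-1/2} L D^{-1/2} for orthogonal restriction maps (illustrative sketch).

    maps: [num_edges, 2, d, d]; maps[e, 0] acts on u and maps[e, 1] on v for edge
    e = (u, v). Assumes every node has at least one edge.
    """
    d = maps.shape[-1]
    L = np.zeros((num_nodes * d, num_nodes * d))
    for e, (u, v) in enumerate(edges):
        fu, fv = maps[e, 0], maps[e, 1]
        su, sv = slice(u * d, (u + 1) * d), slice(v * d, (v + 1) * d)
        L[su, su] += fu.T @ fu   # identity for orthogonal fu
        L[sv, sv] += fv.T @ fv
        L[su, sv] -= fu.T @ fv   # transport from v's stalk into u's stalk
        L[sv, su] -= fv.T @ fu
    # With orthogonal maps each diagonal block is deg(v) * I, so D^{-1/2} is a per-node scalar.
    deg = np.diag(L)[::d]
    scale = np.repeat(1.0 / np.sqrt(deg), d)
    return scale[:, None] * L * scale[None, :]
```

Since orthogonal maps make every diagonal block a multiple of the identity, normalization reduces to scaling by node degrees, exactly as for the normalized graph Laplacian, and the spectrum stays within [0, 2].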

class SheafDiffusion(edge_index, args)

Bases: Module

Base class for sheaf diffusion models.

This class provides the foundational structure for all sheaf diffusion variants, storing common parameters and configurations.

Parameters:
edge_index : torch.Tensor or None

Edge indices of shape [2, num_edges]. Can be None for inductive models.

args : dict

Configuration dictionary containing:

- d (int): Dimension of the stalk space (must be > 0).
- hidden_channels (int): Number of hidden channels per stalk dimension.
- device (str): Device to run the model on.
- layers (int): Number of diffusion layers.
- input_dropout (float): Dropout rate for the input layer.
- dropout (float): Dropout rate for hidden layers.
- input_dim (int): Dimension of input features.
- output_dim (int): Dimension of output features.
- sheaf_act (str): Activation function for sheaf learning.
- orth (str): Orthogonalization method.
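The following example configuration uses the keys listed above, together with a hypothetical pre-flight check that mirrors the constraints documented on this page (d > 0 for the diagonal variant, d > 1 for bundle and general, the supported sheaf_act and orth values). The numeric values are illustrative only, check_config is not part of the module, and the returned total hidden width (hidden_channels * d) is an assumption based on "hidden channels per stalk dimension".

```python
REQUIRED_KEYS = {
    "d", "layers", "hidden_channels", "input_dim", "output_dim",
    "device", "input_dropout", "dropout", "sheaf_act",
}


def check_config(config: dict, variant: str) -> int:
    """Hypothetical pre-flight check; returns the assumed total hidden width."""
    missing = REQUIRED_KEYS - set(config)
    if variant == "bundle" and "orth" not in config:
        missing = missing | {"orth"}
    if missing:
        raise KeyError(f"missing config keys: {sorted(missing)}")
    min_d = 1 if variant == "diag" else 2   # diag: d > 0; bundle/general: d > 1
    if config["d"] < min_d:
        raise ValueError(f"d must be >= {min_d} for the {variant} variant")
    if config["sheaf_act"] not in {"id", "tanh", "elu"}:
        raise ValueError(f"unsupported sheaf_act: {config['sheaf_act']!r}")
    if variant == "bundle" and config["orth"] not in {"cayley", "matrix_exp"}:
        raise ValueError(f"unsupported orth: {config['orth']!r}")
    return config["hidden_channels"] * config["d"]


# Illustrative values only; choose them for your dataset.
config = {
    "d": 3, "layers": 2, "hidden_channels": 16, "input_dim": 32, "output_dim": 7,
    "device": "cpu", "input_dropout": 0.1, "dropout": 0.5,
    "sheaf_act": "tanh", "orth": "cayley",
}
print(check_config(config, "bundle"))  # 48
```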

__init__(edge_index, args)

Initialize the base model. Parameters are as described for the class.