topobench.nn.backbones.graph.nsd_utils.sheaf_models module

class LocalConcatSheafLearner(in_channels, out_shape, sheaf_act='tanh')

Bases: SheafLearner

Sheaf learner that concatenates source and target node features.

For each edge, this learner concatenates the features of the source and target nodes and passes the result through a linear layer followed by the activation selected by sheaf_act.

Parameters:
in_channels : int

Number of input channels per node.

out_shape : tuple of int

Shape of the sheaf parameters produced for each edge: (d,) for a diagonal sheaf or (d, d) for a general sheaf.

sheaf_act : str, optional

Activation function to apply. Options are ‘id’, ‘tanh’, or ‘elu’. Default is ‘tanh’.
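The two out_shape options can be made concrete with a short sketch. It assumes, as in Neural Sheaf Diffusion, that a (d,) output holds the diagonal of a d × d restriction map and that a (d, d) output is the full map. The helper apply_restriction is hypothetical and is not part of topobench.

```python
import torch


def apply_restriction(maps: torch.Tensor, stalk: torch.Tensor) -> torch.Tensor:
    """Apply per-edge restriction maps to per-edge stalk vectors (hypothetical helper).

    maps:  [num_edges, d] for a diagonal sheaf, or [num_edges, d, d] for a general sheaf
    stalk: [num_edges, d]
    """
    if maps.dim() == 2:
        # Diagonal sheaf: the d entries form the diagonal of a d x d map,
        # so applying the map reduces to an elementwise product.
        return maps * stalk
    # General sheaf: one batched matrix-vector product per edge.
    return torch.bmm(maps, stalk.unsqueeze(-1)).squeeze(-1)


num_edges, d = 4, 3
stalk = torch.randn(num_edges, d)
diag_params = torch.randn(num_edges, d)      # out_shape = (d,)
full_params = torch.diag_embed(diag_params)  # the same maps written with out_shape = (d, d)

# For a diagonal map, both encodings give the same result.
print(torch.allclose(apply_restriction(diag_params, stalk),
                     apply_restriction(full_params, stalk)))
```

The diagonal form stores d numbers per edge instead of d², which is the usual reason for offering it.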

Raises:
ValueError

If sheaf_act is not one of the supported activation functions.
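The accepted sheaf_act values and the ValueError can be sketched as a simple lookup. The helper name resolve_sheaf_act and the exact error message are assumptions for illustration. topobench may resolve the activation differently internally.

```python
from typing import Callable

import torch
import torch.nn.functional as F


def resolve_sheaf_act(sheaf_act: str) -> Callable[[torch.Tensor], torch.Tensor]:
    """Map a sheaf_act name to an activation (hypothetical helper)."""
    if sheaf_act == "id":
        return lambda t: t  # identity: leave the linear output unchanged
    if sheaf_act == "tanh":
        return torch.tanh   # squashes each parameter into (-1, 1)
    if sheaf_act == "elu":
        return F.elu
    raise ValueError(f"Unsupported sheaf_act: {sheaf_act!r}")


print(resolve_sheaf_act("tanh")(torch.tensor([0.0])))  # tensor([0.])
try:
    resolve_sheaf_act("relu")
except ValueError as err:
    print(err)  # Unsupported sheaf_act: 'relu'
```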

__init__(in_channels, out_shape, sheaf_act='tanh')

Initialize the learner from in_channels, out_shape, and sheaf_act as described above. Raises ValueError if sheaf_act is unsupported.

forward(x, edge_index)

Compute sheaf parameters from concatenated node features.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, in_channels].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

Returns:
torch.Tensor

Sheaf parameters of shape [num_edges, *out_shape].
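The following is a minimal sketch of the forward computation described above. It gathers the endpoint features of each edge, concatenates them, applies a linear layer and an activation, and reshapes the result to [num_edges, *out_shape]. It assumes that row 0 of edge_index holds source nodes and row 1 holds target nodes. The function concat_sheaf_params is hypothetical and is not topobench's implementation.

```python
import math

import torch
from torch import nn


def concat_sheaf_params(x, edge_index, linear, act, out_shape):
    """Concat-based sheaf parameters for every edge (hypothetical helper).

    x:          [num_nodes, in_channels]
    edge_index: [2, num_edges], row 0 = source nodes, row 1 = target nodes
    linear:     maps 2 * in_channels -> prod(out_shape)
    """
    src, dst = edge_index
    # Gather the features at both endpoints of every edge and concatenate them.
    h = torch.cat([x[src], x[dst]], dim=1)  # [num_edges, 2 * in_channels]
    maps = act(linear(h))                    # [num_edges, prod(out_shape)]
    return maps.view(-1, *out_shape)         # [num_edges, *out_shape]


in_channels, d = 5, 2
x = torch.randn(4, in_channels)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

for out_shape in [(d,), (d, d)]:
    linear = nn.Linear(2 * in_channels, math.prod(out_shape))
    params = concat_sheaf_params(x, edge_index, linear, torch.tanh, out_shape)
    print(out_shape, tuple(params.shape))  # (2,) (3, 2), then (2, 2) (3, 2, 2)
```

The linear layer's input width is twice in_channels because each edge contributes two concatenated feature vectors.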

class SheafLearner

Bases: Module

Base model that learns a sheaf from the features and the graph structure.

This abstract class defines the interface for sheaf learners and provides set_L for storing the learned Laplacian weights.

__init__()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

abstract forward(x, edge_index)

Learn sheaf structure from node features and graph structure.

Parameters:
x : torch.Tensor

Node feature matrix.

edge_index : torch.Tensor

Edge indices of the graph.

Returns:
torch.Tensor

Learned sheaf parameters.

Raises:
NotImplementedError

This is an abstract method that must be implemented by subclasses.
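The base class derives from Module, and torch.nn.Module uses the plain type metaclass rather than ABCMeta. Assuming no ABCMeta is mixed in, the @abstractmethod marker alone therefore does not block instantiation. The NotImplementedError is what reports a missing override, and it fires only when forward is called. The sketch below mirrors that interface. BaseSheafLearnerSketch and MeanSheafLearner are hypothetical names, not topobench classes.

```python
from abc import abstractmethod

import torch
from torch import nn


class BaseSheafLearnerSketch(nn.Module):
    """Hypothetical stand-in for the SheafLearner interface."""

    @abstractmethod
    def forward(self, x, edge_index):
        # nn.Module does not use ABCMeta, so this raise is what actually
        # signals that a subclass forgot to override forward.
        raise NotImplementedError


class MeanSheafLearner(BaseSheafLearnerSketch):
    """Toy subclass: one scalar per edge from the mean of endpoint features."""

    def forward(self, x, edge_index):
        src, dst = edge_index
        return torch.tanh((x[src] + x[dst]).mean(dim=1, keepdim=True))  # [num_edges, 1]


x = torch.randn(3, 4)
edge_index = torch.tensor([[0, 1], [1, 2]])
print(MeanSheafLearner()(x, edge_index).shape)  # torch.Size([2, 1])

try:
    BaseSheafLearnerSketch()(x, edge_index)  # instantiation succeeds; the call raises
except NotImplementedError:
    print("forward must be overridden")
```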

set_L(weights)

Store the learned Laplacian weights.

Parameters:
weights : torch.Tensor

Laplacian weights to store.

Returns:
None
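The docstring states only that set_L stores the weights. One plausible approach, assumed here and not taken from topobench, is to keep a detached copy. The stored Laplacian then carries no autograd history and does not alias a tensor that may later change. LaplacianStoreSketch is a hypothetical name.

```python
import torch
from torch import nn


class LaplacianStoreSketch(nn.Module):
    """Hypothetical module showing one way a set_L-style setter could work."""

    def __init__(self):
        super().__init__()
        self.L = None  # no Laplacian stored until set_L is called

    def set_L(self, weights: torch.Tensor) -> None:
        # clone() gives independent storage; detach() drops the autograd graph.
        self.L = weights.clone().detach()


w = torch.ones(3, requires_grad=True)
learner = LaplacianStoreSketch()
learner.set_L(w * 2)
print(learner.L, learner.L.requires_grad)  # tensor([2., 2., 2.]) False
```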

abstractmethod(funcobj)

A decorator indicating abstract methods.

Requires that the metaclass is ABCMeta or derived from it. A class that has a metaclass derived from ABCMeta cannot be instantiated unless all of its abstract methods are overridden. The abstract methods can be called using any of the normal ‘super’ call mechanisms. abstractmethod() may be used to declare abstract methods for properties and descriptors.
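The instantiation rule above can be demonstrated with the standard-library abc module. Shape and Square are illustrative names only.

```python
from abc import ABCMeta, abstractmethod


class Shape(metaclass=ABCMeta):
    @abstractmethod
    def area(self) -> float:
        ...


class Square(Shape):
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:  # overriding the abstract method makes Square concrete
        return self.side ** 2


try:
    Shape()  # area is still abstract, so ABCMeta refuses to instantiate
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError

print(Square(3).area())  # 9
```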

Usage:

    class C(metaclass=ABCMeta):
        @abstractmethod
        def my_abstract_method(self, …):