topobench.nn.backbones.combinatorial.hopse module

HOPSE model.

class HOPSE(in_channels, hidden_channels, update_func=None, complex_dim=3, max_hop=3, n_layers=2, layer_norm=True)

Bases: Module

HOPSE model.

Parameters:
in_channels : tuple of int or int

Dimension(s) of the input features.

hidden_channels : int

Dimension of the hidden layers and of the output.

update_func : str, optional

Name of the update function (default: None).

complex_dim : int, optional

Dimension of the complex (default: 3).

max_hop : int, optional

Number of hops (default: 3).

n_layers : int, optional

Number of layers (default: 2).

layer_norm : bool, optional

Whether to perform layer normalization (default: True).
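For readers unfamiliar with what `layer_norm=True` switches on, the sketch below shows plain layer normalization of a single feature row: subtract the row mean and divide by the row standard deviation. It is a pure-Python illustration (lists stand in for `torch.Tensor`), and the helper name `layer_norm` is hypothetical. In PyTorch this role is usually played by `torch.nn.LayerNorm`, which additionally learns a per-channel scale and shift; whether HOPSE uses exactly that module is not stated here.

```python
import math
from typing import List


def layer_norm(features: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize one feature vector to zero mean and (almost) unit variance.

    Hypothetical helper for illustration only; no learnable scale/shift.
    """
    n = len(features)
    mean = sum(features) / n
    # Biased variance (divide by n), as in the usual LayerNorm definition.
    var = sum((f - mean) ** 2 for f in features) / n
    # eps keeps the division stable when all features are equal.
    return [(f - mean) / math.sqrt(var + eps) for f in features]


# One row of a hidden-state matrix: a single cell with 4 channels.
row = [1.0, 2.0, 3.0, 4.0]
normed = layer_norm(row)
print([round(v, 3) for v in normed])
```

Normalizing per row (per cell) rather than per batch is what makes layer normalization independent of how many cells a complex has.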

__init__(in_channels, hidden_channels, update_func=None, complex_dim=3, max_hop=3, n_layers=2, layer_norm=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Forward pass of the model.

Parameters:
x : tuple(tuple(torch.Tensor))

Tuple of tuples containing the input tensors for each simplex dimension.

Returns:
tuple(tuple(torch.Tensor))

Tuple of tuples of final hidden state tensors.
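The nested input format can be hard to picture from the type alone. The sketch below builds a shape-only stand-in for `x` and reports the `(n_cells, channels)` shape of every entry. It assumes, hypothetically, that `x[d][h]` holds the features of the `d`-dimensional simplices for the `h`-th hop representation, with `complex_dim` outer entries and `max_hop` inner entries; the docstring above does not pin down the inner length. Nested lists replace `torch.Tensor` so the example runs without PyTorch (with PyTorch, `torch.zeros(n, c)` would take the place of `zeros`), and both helper names are illustrative.

```python
from typing import List, Sequence, Tuple

# Stand-in for a 2-D torch.Tensor of shape (n_cells, channels).
Matrix = List[List[float]]


def zeros(n_rows: int, n_cols: int) -> Matrix:
    """Hypothetical helper: an n_rows x n_cols matrix of zeros."""
    return [[0.0] * n_cols for _ in range(n_rows)]


def describe_hopse_input(x: Sequence[Sequence[Matrix]]) -> List[List[Tuple[int, int]]]:
    """Return the (n_cells, channels) shape of every matrix in a nested input."""
    return [[(len(m), len(m[0]) if m else 0) for m in per_dim] for per_dim in x]


complex_dim, max_hop, channels = 3, 3, 8
# Made-up sizes: 5 nodes, 7 edges, 2 triangles.
n_cells = {0: 5, 1: 7, 2: 2}

# Outer tuple: one entry per simplex dimension; inner tuple: one per hop (assumed).
x = tuple(
    tuple(zeros(n_cells[d], channels) for _ in range(max_hop))
    for d in range(complex_dim)
)
print(describe_hopse_input(x))
```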

class HOPSELayer(in_channels, out_channels, max_hop, aggr_norm=True, update_func=None, initialization='xavier_uniform', layer_norm=True)

Bases: Module

One layer in the HOPSE model.

Parameters:
in_channels : int

Number of input channels.

out_channels : int

Number of output channels.

max_hop : int

Number of hop representations to consider.

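aggr_norm : bool, optional

Whether to perform aggregation normalization (default: True).

update_func : str, optional

Name of the update function (default: None).

initialization : str, optional

Weight initialization method (default: 'xavier_uniform').

layer_norm : bool, optional

Whether to apply layer normalization (default: True).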

Returns:
torch.Tensor

Output
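The default `initialization='xavier_uniform'` refers to Glorot/Xavier uniform initialization: each weight is drawn from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), the same bound `torch.nn.init.xavier_uniform_` uses. The pure-Python sketch below samples an `out_channels x in_channels` matrix that way. The helper `xavier_uniform` and the weight shape are assumptions for illustration (the shape follows the `torch.nn.Linear` convention); the actual parameter shapes inside `HOPSELayer` are not documented here.

```python
import math
import random
from typing import List


def xavier_uniform(fan_out: int, fan_in: int, gain: float = 1.0,
                   seed: int = 0) -> List[List[float]]:
    """Sample a fan_out x fan_in matrix with Glorot/Xavier uniform init.

    Entries come from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)).
    """
    rng = random.Random(seed)  # seeded for reproducibility
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)]
            for _ in range(fan_out)]


in_channels, out_channels = 16, 32
w = xavier_uniform(out_channels, in_channels)
bound = math.sqrt(6.0 / (in_channels + out_channels))
print(len(w), len(w[0]), round(bound, 4))
```

Scaling the bound by fan_in + fan_out keeps activation and gradient variances roughly constant across layers, which is why it is a common default for linear maps.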

__init__(in_channels, out_channels, max_hop, aggr_norm=True, update_func=None, initialization='xavier_uniform', layer_norm=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x_all)

Forward computation.

Parameters:
x_all : Dict[int, torch.Tensor]

Dictionary mapping each simplex dimension, 0-indexed (0: nodes, 1: edges, 2: faces), to its input tensor.

Returns:
torch.Tensor

Output tensors for each 0-cell.

torch.Tensor

Output tensors for each 1-cell.

torch.Tensor

Output tensors for each 2-cell.
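To make the input/output contract concrete, the sketch below keys a dictionary by simplex dimension and returns one output per rank, ordered (0-cells, 1-cells, 2-cells), which matches the three return values listed above. The computation itself is a hypothetical stand-in (a separate linear map per rank), not the layer's actual message passing; `per_rank_forward`, `matmul`, and the weights are illustrative names, and nested lists replace `torch.Tensor`.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row))
             for j in range(len(b[0]))] for row in a]


def per_rank_forward(x_all: Dict[int, Matrix],
                     weights: Dict[int, Matrix]) -> Tuple[Matrix, ...]:
    """Hypothetical stand-in for HOPSELayer.forward.

    Applies one linear map per rank and returns outputs ordered by rank.
    """
    return tuple(matmul(x_all[rank], weights[rank]) for rank in sorted(x_all))


# A single filled triangle: 3 nodes, 3 edges, 1 face, 2 channels each.
x_all = {
    0: [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    1: [[2.0, 1.0], [1.0, 3.0], [0.0, 1.0]],
    2: [[1.0, 2.0]],
}
identity = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]  # swaps the two channels
x_0, x_1, x_2 = per_rank_forward(x_all, {0: identity, 1: swap, 2: identity})
print(x_1)  # [[1.0, 2.0], [3.0, 1.0], [1.0, 0.0]]
```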

update(x)

Update embeddings on each cell (step 4).

Parameters:
x : torch.Tensor

Input tensor.

Returns:
torch.Tensor

Updated tensor.
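Many TopoModelX-style layers implement the update step by dispatching on `update_func`: `"relu"` and `"sigmoid"` apply the matching element-wise activation, and `None` leaves the input unchanged. Whether `HOPSELayer.update` follows exactly this convention is an assumption; the sketch below shows that dispatch pattern in pure Python, with the mapping `_UPDATE_FUNCS` and the list-based `update` being hypothetical stand-ins for the tensor version.

```python
import math
from typing import Callable, Dict, List, Optional


def _relu(v: float) -> float:
    return max(0.0, v)


def _sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


# Hypothetical mapping from update_func strings to element-wise functions.
_UPDATE_FUNCS: Dict[str, Callable[[float], float]] = {
    "relu": _relu,
    "sigmoid": _sigmoid,
}


def update(x: List[float], update_func: Optional[str] = None) -> List[float]:
    """Apply the named activation element-wise; None returns x unchanged."""
    if update_func is None:
        return list(x)
    if update_func not in _UPDATE_FUNCS:
        raise ValueError(f"unknown update_func: {update_func!r}")
    fn = _UPDATE_FUNCS[update_func]
    return [fn(v) for v in x]


print(update([-1.0, 0.0, 2.0], "relu"))  # [0.0, 0.0, 2.0]
print(update([-1.0, 0.0, 2.0]))          # [-1.0, 0.0, 2.0]
```

Raising on an unknown name, rather than silently skipping the update, surfaces configuration typos early.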