topobench.nn.backbones.combinatorial.hopse module
HOPSE model.
- class HOPSE(in_channels, hidden_channels, update_func=None, complex_dim=3, max_hop=3, n_layers=2, layer_norm=True)
Bases: torch.nn.Module
HOPSE model.
- Parameters:
- in_channels : tuple of int or int
Dimension of the input features (a single int, or a tuple of ints).
- hidden_channels : int
Dimension of the hidden layers and of the output layer.
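- update_func : str, optional
Update function (default: None).
- complex_dim : int, optional
Dimension of the complex (default: 3).
- max_hop : int, optional
Number of hops (default: 3).
- n_layers : int, optional
Number of layers (default: 2).
- layer_norm : bool, optional
Whether to perform layer normalization (default: True).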
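`in_channels` accepts either a single int or a tuple of ints. A minimal sketch of how such an argument could be normalized into one input dimension per rank, assuming an int applies to every rank and a tuple lists the ranks in order (0: nodes, 1: edges, 2: faces). The helper `normalize_in_channels` and its `n_ranks` argument are hypothetical and not part of `topobench`.

```python
def normalize_in_channels(in_channels, n_ranks=3):
    """Return one input dimension per rank (hypothetical helper).

    Assumes an int is shared by all ranks and a tuple gives the
    dimension of each rank in order (0: nodes, 1: edges, 2: faces).
    """
    if isinstance(in_channels, int):
        # Broadcast a single dimension to every rank.
        return (in_channels,) * n_ranks
    dims = tuple(int(d) for d in in_channels)
    if len(dims) != n_ranks:
        raise ValueError(
            f"expected {n_ranks} input dimensions, got {len(dims)}"
        )
    return dims


print(normalize_in_channels(16))           # (16, 16, 16)
print(normalize_in_channels((8, 16, 32)))  # (8, 16, 32)
```

Raising on a length mismatch surfaces a misconfigured tuple at construction time rather than as a shape error deep inside the forward pass.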
- __init__(in_channels, hidden_channels, update_func=None, complex_dim=3, max_hop=3, n_layers=2, layer_norm=True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Forward pass of the model.
- Parameters:
- x : tuple(tuple(torch.Tensor))
Tuple of tuples containing the input tensors for each simplex.
- Returns:
- tuple(tuple(torch.Tensor))
Tuple of tuples of final hidden state tensors.
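The nested-tuple input is easiest to understand by building one. A minimal NumPy sketch, assuming the outer tuple indexes simplex rank and the inner tuple holds one `(n_cells, in_dim)` array per hop representation; the exact number of hop entries HOPSE expects is not stated in this reference, so the toy input simply uses `max_hop` entries per rank. The helper `check_hopse_input` is hypothetical.

```python
import numpy as np


def check_hopse_input(x, in_dims):
    """Validate a nested input x[rank][hop] (hypothetical helper).

    Assumes the outer tuple indexes simplex rank and the inner tuple
    holds one (n_cells, in_dim) array per hop representation.
    Returns the number of cells for each rank.
    """
    n_cells = []
    for rank, (hops, in_dim) in enumerate(zip(x, in_dims)):
        rows = {h.shape[0] for h in hops}
        if len(rows) != 1:
            raise ValueError(f"rank {rank}: hop tensors disagree on cell count")
        for h in hops:
            if h.shape[1] != in_dim:
                raise ValueError(f"rank {rank}: expected {in_dim} features")
        n_cells.append(rows.pop())
    return tuple(n_cells)


rng = np.random.default_rng(0)
max_hop = 3
cells = (5, 8, 4)  # nodes, edges, faces of a toy complex
x = tuple(
    tuple(rng.normal(size=(n, 16)) for _ in range(max_hop))
    for n in cells
)
print(check_hopse_input(x, in_dims=(16, 16, 16)))  # (5, 8, 4)
```

Every hop representation of a given rank describes the same cells, so they must agree on the number of rows even though different ranks have different cell counts.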
- class HOPSELayer(in_channels, out_channels, max_hop, aggr_norm=True, update_func=None, initialization='xavier_uniform', layer_norm=True)
Bases: torch.nn.Module
One layer in the HOPSE model.
- Parameters:
- in_channels : int
Number of input channels.
- out_channels : int
Number of output channels.
- max_hop : int
Number of hop representations to consider.
- aggr_norm : bool, optional
Whether to perform aggregation normalization (default: True).
- update_func : str, optional
Update function (default: None).
- initialization : str, optional
Initialization method (default: 'xavier_uniform').
- layer_norm : bool, optional
Whether to apply layer normalization (default: True).
- Returns:
- torch.Tensor
Output tensor of the layer.
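The exact computation of a HOPSE layer is not spelled out in this reference. The following is a minimal NumPy sketch of one plausible reading of its parameters, assuming: each of the `max_hop` hop representations gets its own linear map, `aggr_norm=True` averages the projected hops instead of summing them, `layer_norm` normalizes each cell's features (without a learned scale or shift), `update_func` is an optional elementwise activation, and `initialization='xavier_uniform'` is mimicked with a Glorot-uniform draw. The class `HopLayerSketch` is hypothetical; it is not the `topobench` implementation.

```python
import numpy as np


class HopLayerSketch:
    """Hypothetical NumPy stand-in for a HOPSE-style layer."""

    def __init__(self, in_channels, out_channels, max_hop,
                 aggr_norm=True, update_func=None, layer_norm=True, seed=0):
        rng = np.random.default_rng(seed)
        # Xavier (Glorot) uniform bound: sqrt(6 / (fan_in + fan_out)).
        bound = np.sqrt(6.0 / (in_channels + out_channels))
        # One weight matrix per hop representation (an assumption).
        self.weights = [
            rng.uniform(-bound, bound, size=(in_channels, out_channels))
            for _ in range(max_hop)
        ]
        self.aggr_norm = aggr_norm
        self.update_func = update_func
        self.layer_norm = layer_norm

    def __call__(self, hops, eps=1e-5):
        # hops: sequence of max_hop arrays, each (n_cells, in_channels).
        out = sum(h @ w for h, w in zip(hops, self.weights))
        if self.aggr_norm:
            out = out / len(self.weights)  # mean instead of sum
        if self.layer_norm:
            # Normalize each cell's feature vector to zero mean, unit variance.
            mu = out.mean(axis=1, keepdims=True)
            var = out.var(axis=1, keepdims=True)
            out = (out - mu) / np.sqrt(var + eps)
        if self.update_func == "relu":
            out = np.maximum(out, 0.0)
        elif self.update_func == "sigmoid":
            out = 1.0 / (1.0 + np.exp(-out))
        return out


rng = np.random.default_rng(1)
hops = [rng.normal(size=(6, 4)) for _ in range(3)]
layer = HopLayerSketch(in_channels=4, out_channels=8, max_hop=3)
y = layer(hops)
print(y.shape)  # (6, 8)
```

Averaging rather than summing keeps the output scale independent of how many hops are combined, which is one common motivation for an aggregation-normalization flag.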
- __init__(in_channels, out_channels, max_hop, aggr_norm=True, update_func=None, initialization='xavier_uniform', layer_norm=True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x_all)
Forward computation.
- Parameters:
- x_all : Dict[int, torch.Tensor]
Dictionary mapping each simplex dimension, 0-indexed (0: node, 1: edge, 2: face), to its input tensor.
- Returns:
- torch.Tensor
Output tensors for each 0-cell.
- torch.Tensor
Output tensors for each 1-cell.
- torch.Tensor
Output tensors for each 2-cell.
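The dictionary-in, tuple-out contract can be illustrated with a toy complex. A minimal NumPy sketch, assuming each rank is processed by its own function and the outputs are returned in rank order; the helper `forward_all_ranks` and the random per-rank linear maps are hypothetical stand-ins for the layer's actual computation.

```python
import numpy as np


def forward_all_ranks(x_all, rank_fns):
    """Apply one function per rank and return outputs ordered by rank.

    Hypothetical helper: x_all maps rank (0: nodes, 1: edges,
    2: faces) to an (n_cells, channels) array, and rank_fns maps the
    same keys to callables. Returns (out_0, out_1, out_2, ...).
    """
    return tuple(rank_fns[r](x_all[r]) for r in sorted(x_all))


rng = np.random.default_rng(0)
x_all = {
    0: rng.normal(size=(5, 16)),  # 5 nodes
    1: rng.normal(size=(8, 16)),  # 8 edges
    2: rng.normal(size=(4, 16)),  # 4 faces
}
# A separate weight matrix per rank; the default argument binds each one.
weights = {r: rng.normal(size=(16, 32)) for r in x_all}
rank_fns = {r: (lambda h, w=weights[r]: h @ w) for r in x_all}
x_0, x_1, x_2 = forward_all_ranks(x_all, rank_fns)
print(x_0.shape, x_1.shape, x_2.shape)  # (5, 32) (8, 32) (4, 32)
```

Each output keeps its rank's cell count while the channel dimension changes, matching the three per-rank return values listed above.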
- update(x)
Update embeddings on each cell (step 4).
- Parameters:
- x : torch.Tensor
Input tensor.
- Returns:
- torch.Tensor
Updated tensor.
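`update` applies the layer's update function to a tensor. A minimal NumPy sketch of a name-based dispatch, assuming `update_func` is either `None` (leave the tensor unchanged) or the name of an elementwise activation; the table `UPDATE_FUNCS`, the supported names, and the error on unknown names are illustrative assumptions, not the documented behavior.

```python
import numpy as np

# Hypothetical name -> activation table; None means "no update".
UPDATE_FUNCS = {
    None: lambda x: x,
    "relu": lambda x: np.maximum(x, 0.0),
    "sigmoid": lambda x: 1.0 / (1.0 + np.exp(-x)),
    "tanh": np.tanh,
}


def update(x, update_func=None):
    """Apply the named elementwise update to x (hypothetical sketch)."""
    try:
        fn = UPDATE_FUNCS[update_func]
    except KeyError:
        raise ValueError(f"unknown update_func: {update_func!r}") from None
    return fn(x)


x = np.array([-1.0, 0.0, 2.0])
print(update(x))          # [-1.  0.  2.]
print(update(x, "relu"))  # [0. 0. 2.]
```

A lookup table keeps the accepted names in one place and turns a typo in `update_func` into an explicit error instead of a silent identity.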