topobench.nn.wrappers.combinatorial package
Wrappers for combinatorial neural networks with automated exports.
- class HOPSEWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the HOPSE model.
- Parameters:
- backbone : torch.nn.Module
Backbone model.
- **kwargs : dict
Additional arguments for the class. It should contain the following keys:
  - out_channels (int): Number of output channels.
  - num_cell_dimensions (int): Number of cell dimensions.
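The required keys can be checked before a wrapper is constructed. The following is a minimal sketch in plain Python; `check_wrapper_kwargs` and `REQUIRED_KEYS` are hypothetical names introduced for illustration and are not part of topobench. It only mirrors the two keys documented above.

```python
# Hypothetical helper, not part of topobench: validates the two keys
# that the HOPSEWrapper documentation says **kwargs should contain.
REQUIRED_KEYS = {"out_channels": int, "num_cell_dimensions": int}


def check_wrapper_kwargs(**kwargs):
    """Return the documented keys, raising if one is missing or mistyped."""
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in kwargs:
            raise KeyError(f"missing required keyword argument: {key!r}")
        if not isinstance(kwargs[key], expected_type):
            raise TypeError(f"{key!r} must be of type {expected_type.__name__}")
    return {key: kwargs[key] for key in REQUIRED_KEYS}


# Example values are arbitrary and chosen only for illustration.
config = check_wrapper_kwargs(out_channels=32, num_cell_dimensions=3)
```

A check of this kind fails early with a clear message instead of surfacing later as an error inside the wrapper.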
- __init__(backbone, **kwargs)
- forward(batch)
Forward pass of the HOPSE wrapper.
- Parameters:
- batch : dict
Dictionary containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output.
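The dictionary-in, dictionary-out contract of `forward` can be illustrated with a stand-in class. This is a sketch of the interface shape only, written without torch or topobench; `ToyWrapper`, `toy_backbone`, and the `x_0`, `x_1`, ... key names are assumptions made for the example, and the real `HOPSEWrapper.forward` may process the batch quite differently.

```python
class ToyWrapper:
    """Stand-in for a wrapper; NOT the topobench implementation."""

    def __init__(self, backbone, **kwargs):
        self.backbone = backbone
        # The two keys the documentation says **kwargs should contain.
        self.out_channels = kwargs["out_channels"]
        self.num_cell_dimensions = kwargs["num_cell_dimensions"]

    def forward(self, batch):
        """Apply the backbone to each rank's features and return a dict."""
        model_out = {}
        for rank in range(self.num_cell_dimensions):
            key = f"x_{rank}"  # hypothetical key naming, one entry per rank
            if key in batch:
                model_out[key] = self.backbone(batch[key])
        return model_out


def toy_backbone(features):
    """Toy stand-in for a torch.nn.Module backbone: doubles each value."""
    return [2 * value for value in features]


wrapper = ToyWrapper(toy_backbone, out_channels=1, num_cell_dimensions=2)
output = wrapper.forward({"x_0": [1, 2], "x_1": [3], "x_2": [4]})
# Only ranks 0 and 1 are processed because num_cell_dimensions is 2.
```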
- class TuneWrapper(backbone, **kwargs)
Bases: AbstractWrapper
Wrapper for the TopoTune model.
This wrapper defines the forward pass of the TopoTune model.
- forward(batch)
Forward pass for the Tune wrapper.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.