topobench.nn.wrappers.combinatorial package

Wrappers for combinatorial neural networks. The wrapper classes are exported automatically at the package level.

class HOPSEWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the HOPSE model.

Parameters:
backbone : torch.nn.Module

Backbone model.

**kwargs : dict

Additional arguments for the class. It should contain the following keys:

- out_channels (int): Number of output channels.
- num_cell_dimensions (int): Number of cell dimensions.
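The two keys listed above are the only documented requirements for **kwargs. The following is a minimal sketch of checking them before constructing a wrapper. The helper name validate_wrapper_kwargs is hypothetical and not part of topobench. It assumes both values must be plain integers and that any extra keys are simply passed over.

```python
from typing import Any, Dict


def validate_wrapper_kwargs(**kwargs: Any) -> Dict[str, int]:
    """Check for the keys HOPSEWrapper documents as required.

    Hypothetical helper for illustration; it is not part of topobench.
    """
    required = ("out_channels", "num_cell_dimensions")
    missing = [key for key in required if key not in kwargs]
    if missing:
        raise KeyError(f"missing required kwargs: {missing}")
    for key in required:
        value = kwargs[key]
        # Both documented keys are ints; bool is rejected because it subclasses int.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int, got {type(value).__name__}")
    # Extra keys are ignored here; the real wrapper may consume them.
    return {key: kwargs[key] for key in required}
```

For example, validate_wrapper_kwargs(out_channels=64, num_cell_dimensions=3) returns both values, and omitting num_cell_dimensions raises KeyError.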

__init__(backbone, **kwargs)
forward(batch)

Forward pass of the HOPSE model.

Parameters:
batch : Dict

Dictionary containing the batched domain data.

Returns:
dict

Dictionary containing the model output.
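The contract above is a batch dictionary in and an output dictionary out. Below is a framework-free sketch of that pattern, with a plain Python callable standing in for the torch.nn.Module backbone so it runs without PyTorch. The class SketchWrapper, the per-dimension key scheme x_0, x_1, ..., the "labels" key, and the assumption that the backbone returns one output per cell dimension are all hypothetical. This is not the HOPSEWrapper implementation.

```python
from typing import Any, Callable, Dict


class SketchWrapper:
    """Hypothetical sketch of the wrapper pattern: batch dict in, dict out.

    The real HOPSEWrapper subclasses AbstractWrapper and wraps a torch.nn.Module.
    """

    def __init__(self, backbone: Callable[..., Any], num_cell_dimensions: int) -> None:
        self.backbone = backbone
        self.num_cell_dimensions = num_cell_dimensions

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical key scheme: one feature entry per cell dimension.
        features = [batch[f"x_{dim}"] for dim in range(self.num_cell_dimensions)]
        # Assumed: the backbone returns one output per cell dimension, in order.
        outputs = self.backbone(*features)
        model_out: Dict[str, Any] = {"labels": batch.get("y")}
        for dim, out in enumerate(outputs):
            model_out[f"x_{dim}"] = out
        return model_out
```

Keeping the backbone behind a callable makes the wrapper responsible only for unpacking the batch and packing the outputs, which is the division of labour the documented signature suggests.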

class TuneWrapper(backbone, **kwargs)

Bases: AbstractWrapper

Wrapper for the TopoTune model.

This wrapper defines the forward pass of the TopoTune model.

forward(batch)

Forward pass for the Tune wrapper.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
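Unlike HOPSEWrapper.forward, which takes a Dict, this method takes a torch_geometric.data.Data batch, whose fields are read as attributes. The sketch below illustrates that difference. types.SimpleNamespace stands in for a Data object so it runs without PyTorch Geometric. The function tune_style_forward and the field names x_0 and y are hypothetical; this is not the TopoTune implementation.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


def tune_style_forward(backbone: Callable[[Any], Any], batch: Any) -> Dict[str, Any]:
    """Read fields by attribute, as on a Data batch, and collect outputs in a dict.

    Hypothetical sketch; not TopoTune's code.
    """
    # Hypothetical field names: x_0 for node features, y for targets.
    model_out: Dict[str, Any] = {"labels": batch.y}
    # Update the output dict with the backbone's node-level representation.
    model_out["x_0"] = backbone(batch.x_0)
    return model_out


# SimpleNamespace stands in for a Data object so the sketch runs without PyG.
example_batch = SimpleNamespace(x_0=[1.0, 2.0], y=[1])
print(tune_style_forward(lambda x: [v + 1.0 for v in x], example_batch))
```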

Submodules