topobench.nn.backbones.combinatorial package

Combinatorial backbones with automated exports.

class TopoTune(GNN, neighborhoods, layers, use_edge_attr, activation)

Bases: Module

Tunes a GNN model using higher-order relations.

This class takes a GNN and its kwargs as inputs, and tunes it with specified additional relations.

Parameters:
GNN : torch.nn.Module, a class not an object

The GNN class to use, e.g., GAT, GCN.

neighborhoods : list of lists

The neighborhoods of interest.

layers : int

The number of layers to use. Each layer contains one GNN.

use_edge_attr : bool

Whether to use edge attributes.

activation : str

The activation function to use, e.g., ‘relu’, ‘tanh’, ‘sigmoid’.

__init__(GNN, neighborhoods, layers, use_edge_attr, activation)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

aggregate_inter_nbhd(x_out_per_route)

Aggregate the outputs of the GNN for each rank.

The GNN handles intra-neighborhood aggregation; this method handles inter-neighborhood aggregation. Default: sum.

Parameters:
x_out_per_route : dict

The outputs of the GNN for each route.

Returns:
dict

The aggregated outputs of the GNN for each rank.
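
The inter-neighborhood sum described above can be illustrated with a minimal, library-free sketch. It does not reproduce the TopoTune implementation: the function name, the `route_dst_rank` mapping, and the use of nested lists in place of tensors are all assumptions made for illustration.

```python
from collections import defaultdict


def aggregate_by_destination_rank(x_out_per_route, route_dst_rank):
    """Sum per-route feature matrices that target the same rank.

    x_out_per_route: {route_index: list of feature rows} (assumed layout)
    route_dst_rank:  {route_index: destination rank} (hypothetical helper)
    """
    grouped = defaultdict(list)
    for route_index, features in x_out_per_route.items():
        grouped[route_dst_rank[route_index]].append(features)

    aggregated = {}
    for rank, outputs in grouped.items():
        # Element-wise sum over all routes that end at this rank.
        aggregated[rank] = [
            [sum(values) for values in zip(*rows)]
            for rows in zip(*outputs)
        ]
    return aggregated


# Routes 0 and 1 both update rank-0 cells; route 2 updates rank-1 cells.
x_out_per_route = {
    0: [[1.0, 2.0], [3.0, 4.0]],
    1: [[10.0, 20.0], [30.0, 40.0]],
    2: [[5.0, 5.0]],
}
route_dst_rank = {0: 0, 1: 0, 2: 1}
result = aggregate_by_destination_rank(x_out_per_route, route_dst_rank)
```

Routes that share a destination rank must produce the same number of rows (one per destination cell) for the element-wise sum to be meaningful.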

forward(batch)

Forward pass of the model.

Parameters:
batch : Complex or ComplexBatch(Complex)

The input data.

Returns:
dict

The output hidden states of the model per rank.

generate_membership_vectors(batch)

Generate membership vectors based on batch.cell_statistics.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

The batch membership of the graphs per rank.
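
As a rough illustration of what a membership vector is, the sketch below builds one per rank from a table of cell counts. It assumes that the cell statistics hold one row per graph and one column per rank; that layout, the function name, and the plain-list representation are assumptions, not the library's code.

```python
def membership_from_cell_statistics(cell_statistics):
    """Map each cell of each rank to the index of the graph it belongs to.

    cell_statistics[g][r] is assumed to be the number of rank-r cells
    in graph g of the batch.
    """
    n_ranks = len(cell_statistics[0]) if cell_statistics else 0
    membership = {}
    for rank in range(n_ranks):
        vector = []
        for graph_index, counts in enumerate(cell_statistics):
            # Repeat the graph index once per cell of this rank.
            vector.extend([graph_index] * counts[rank])
        membership[rank] = vector
    return membership


# Graph 0: 3 nodes, 3 edges, 1 face. Graph 1: 2 nodes, 1 edge, no faces.
cell_statistics = [[3, 3, 1], [2, 1, 0]]
membership = membership_from_cell_statistics(cell_statistics)
```

The resulting vectors play the same role as the `batch` vector of a batched graph: they tell pooling and expansion steps which graph each cell came from.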

get_nbhd_cache(params)

Cache the neighborhood information in a dict for the complex at hand.

Parameters:
params : dict

The parameters of the batch, containing the complex.

Returns:
dict

The neighborhood cache.

interrank_expand(params, src_rank, dst_rank, nbhd_cache, membership)

Expand the complex into an interrank Hasse graph.

Parameters:
params : dict

The parameters of the batch, containing the complex.

src_rank : int

The source rank.

dst_rank : int

The destination rank.

nbhd_cache : dict

The neighborhood cache containing the expanded boundary index and edge attributes.

membership : dict

The batch membership of the graphs per rank.

Returns:
torch_geometric.data.Data

The expanded batch of interrank Hasse graphs for this route.
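
One way to picture an interrank Hasse graph is as a bipartite graph whose nodes are the source-rank and destination-rank cells placed in a shared index space. The sketch below is only an illustration under that assumption: the function name, the `(src_cell, dst_cell)` pair format, and the choice to shift destination indices by the number of source cells are hypothetical and may differ from the actual expansion.

```python
def interrank_hasse_edges(incidence_pairs, n_src_cells):
    """Return directed edges src -> dst over a shared node index space.

    incidence_pairs: iterable of (src_cell, dst_cell), each indexed
    within its own rank. Source cells keep indices 0..n_src_cells-1;
    destination cells are shifted by n_src_cells.
    """
    senders = [src for src, _ in incidence_pairs]
    receivers = [dst + n_src_cells for _, dst in incidence_pairs]
    return [senders, receivers]


# A triangle: 3 nodes (rank 0) and 3 edges (rank 1).
# Each pair links a node to an edge that contains it.
node_to_edge = [(0, 0), (1, 0), (1, 1), (2, 1), (0, 2), (2, 2)]
edge_index = interrank_hasse_edges(node_to_edge, n_src_cells=3)
```

The two-row result follows the usual `edge_index` layout of graph libraries, so an ordinary message-passing layer can send features from source cells to destination cells.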

interrank_gnn_forward(batch_route, layer_idx, route_index, n_dst_cells)

Forward pass of the GNN (one layer) for an interrank Hasse graph.

Parameters:
batch_route : torch_geometric.data.Data

The batch of interrank Hasse graphs for this route.

layer_idx : int

The index of the layer.

route_index : int

The index of the route.

n_dst_cells : int

The number of destination cells in the whole batch.

Returns:
torch.Tensor

The output of the GNN (updated features).

intrarank_expand(params, src_rank, nbhd)

Expand the complex into an intrarank Hasse graph.

Parameters:
params : dict

The parameters of the batch, containing the complex.

src_rank : int

The source rank.

nbhd : str

The neighborhood to use.

Returns:
torch_geometric.data.Data

The expanded batch of intrarank Hasse graphs for this route.
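
For the intrarank case, every node of the expanded graph is a cell of the same rank, and edges come from an adjacency relation. As a hedged example, the sketch below derives up-adjacency between rank-0 cells (two nodes are neighbors when they share an edge). The function name and input format are hypothetical, and the neighborhoods handled by the library are not limited to this one.

```python
def up_adjacency_edges(edges):
    """Build a symmetric edge index between nodes that share an edge.

    edges: list of (u, v) node pairs, i.e. rank-1 cells described by
    their two boundary nodes.
    """
    directed = set()
    for u, v in edges:
        # Store both directions so messages flow both ways.
        directed.add((u, v))
        directed.add((v, u))
    ordered = sorted(directed)
    return [[u for u, _ in ordered], [v for _, v in ordered]]


# A path 0 - 1 - 2: node 1 is up-adjacent to both other nodes.
edge_index = up_adjacency_edges([(0, 1), (1, 2)])
```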

intrarank_gnn_forward(batch_route, layer_idx, route_index)

Forward pass of the GNN (one layer) for an intrarank Hasse graph.

Parameters:
batch_route : torch_geometric.data.Data

The batch of intrarank Hasse graphs for this route.

layer_idx : int

The index of the TopoTune layer.

route_index : int

The index of the route.

Returns:
torch.Tensor

The output of the GNN (updated features).

class TopoTune_OneHasse(GNN, neighborhoods, layers, use_edge_attr, activation)

Bases: Module

Tunes a GNN model using higher-order relations.

This class takes a GNN and its kwargs as inputs, and tunes it with specified additional relations. Unlike TopoTune, this class expects a single Hasse graph as input, in which all higher-order neighborhoods are represented as a single adjacency matrix.

Parameters:
GNN : torch.nn.Module, a class not an object

The GNN class to use, e.g., GAT, GCN.

neighborhoods : list of lists

The neighborhoods of interest.

layers : int

The number of layers to use. Each layer contains one GNN.

use_edge_attr : bool

Whether to use edge attributes.

activation : str

The activation function to use, e.g., ‘relu’, ‘tanh’, ‘sigmoid’.

__init__(GNN, neighborhoods, layers, use_edge_attr, activation)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

aggregate_inter_nbhd(x_out)

Aggregate the outputs of the GNN for each rank.

The GNN handles intra-neighborhood aggregation; this method handles inter-neighborhood aggregation. Default: sum.

Parameters:
x_out : torch.Tensor

The output of the GNN, concatenated features of each rank.

Returns:
dict

The aggregated outputs of the GNN for each rank.
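
Because the input here is a single matrix holding the features of every rank, returning a per-rank dict amounts to slicing that matrix. The sketch below shows one way to do this, assuming rows are ordered by rank and the number of cells per rank is known; both assumptions, along with the function name, are illustrative rather than taken from the implementation.

```python
def split_by_rank(x_out, cells_per_rank):
    """Slice concatenated feature rows into one block per rank.

    x_out: list of feature rows, assumed ordered rank 0 first, then
    rank 1, and so on.
    cells_per_rank: number of cells of each rank, in the same order.
    """
    out = {}
    start = 0
    for rank, n_cells in enumerate(cells_per_rank):
        out[rank] = x_out[start:start + n_cells]
        start += n_cells
    return out


# 3 nodes followed by 2 edges, each with a 2-dimensional feature.
x_out = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [1.0, 1.1], [1.2, 1.3]]
per_rank = split_by_rank(x_out, cells_per_rank=[3, 2])
```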

all_nbhds_expand(params, membership)

Expand the complex into a single Hasse graph that contains all ranks and all neighborhoods.

Parameters:
params : dict

The parameters of the batch, containing the complex.

membership : dict

The batch membership of the graphs per rank.

Returns:
torch_geometric.data.Data

The expanded Hasse graph.
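
Merging every neighborhood into one graph requires a global node numbering across ranks. The following sketch illustrates that idea with per-rank index offsets. The `(src_rank, dst_rank, pairs)` input format and the function name are hypothetical, and the real method may order or encode cells differently.

```python
def merge_neighborhoods(cells_per_rank, neighborhoods):
    """Combine several neighborhoods into one edge index.

    cells_per_rank: number of cells of each rank.
    neighborhoods: list of (src_rank, dst_rank, pairs), where pairs
    are (src_cell, dst_cell) indices local to their own rank.
    """
    # offsets[r] is the global index of the first rank-r cell.
    offsets = []
    total = 0
    for n_cells in cells_per_rank:
        offsets.append(total)
        total += n_cells

    senders, receivers = [], []
    for src_rank, dst_rank, pairs in neighborhoods:
        for src, dst in pairs:
            senders.append(src + offsets[src_rank])
            receivers.append(dst + offsets[dst_rank])
    return [senders, receivers], total


# 3 nodes and 2 edges: one node-to-node adjacency and one
# node-to-edge incidence, merged into a single graph.
neighborhoods = [
    (0, 0, [(0, 1), (1, 0)]),
    (0, 1, [(0, 0), (1, 0), (1, 1), (2, 1)]),
]
edge_index, n_nodes = merge_neighborhoods([3, 2], neighborhoods)
```

With all relations in one edge index, a single GNN layer updates every rank in one pass, at the cost of no longer distinguishing routes unless edge attributes encode them.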

all_nbhds_gnn_forward(batch_route, layer_idx)

Forward pass of the GNN (one layer) for the single Hasse graph that contains all ranks and all neighborhoods.

Parameters:
batch_route : torch_geometric.data.Data

The batch of expanded Hasse graphs containing all ranks and all neighborhoods.

layer_idx : int

The index of the TopoTune layer.

Returns:
torch.Tensor

The output of the GNN (updated features).

forward(batch)

Forward pass of the model.

Parameters:
batch : Complex or ComplexBatch(Complex)

The input data.

Returns:
dict

The output hidden states of the model per rank.

generate_membership_vectors(batch)

Generate membership vectors based on batch.cell_statistics.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

The batch membership of the graphs per rank.

Submodules