topobench.nn.backbones.combinatorial package
Combinatorial backbones with automated exports.
- class TopoTune(GNN, neighborhoods, layers, use_edge_attr, activation)
Bases: Module
Tunes a GNN model using higher-order relations.
This class takes a GNN class and its keyword arguments as inputs and tunes it with the specified additional relations.
- Parameters:
- GNN : torch.nn.Module (a class, not an instance)
The GNN class to use, e.g. GAT or GCN.
- neighborhoods : list of lists
The neighborhoods of interest.
- layers : int
The number of layers to use. Each layer contains one GNN.
- use_edge_attr : bool
Whether to use edge attributes.
- activation : str
The activation function to use, e.g. 'relu', 'tanh', or 'sigmoid'.
- __init__(GNN, neighborhoods, layers, use_edge_attr, activation)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- aggregate_inter_nbhd(x_out_per_route)
Aggregate the outputs of the GNN for each rank.
The GNN handles intra-neighborhood aggregation; this method handles inter-neighborhood aggregation. Default: sum.
- Parameters:
- x_out_per_route : dict
The outputs of the GNN for each route.
- Returns:
- dict
The aggregated outputs of the GNN for each rank.
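The default sum can be pictured with the framework-free sketch below. It is not the library's implementation: it assumes, hypothetically, that routes are keyed as `(src_rank, dst_rank)` tuples and that each route's output is a list of per-cell feature vectors for its destination rank; the real method works on tensors and its key format may differ.

```python
from typing import Dict, List, Tuple

Route = Tuple[int, int]        # hypothetical key: (src_rank, dst_rank)
Features = List[List[float]]   # one feature vector per destination cell


def aggregate_inter_nbhd_sketch(x_out_per_route: Dict[Route, Features]) -> Dict[int, Features]:
    """Sum the outputs of all routes that land on the same destination rank."""
    x_out_per_rank: Dict[int, Features] = {}
    for (_, dst_rank), x_out in x_out_per_route.items():
        if dst_rank not in x_out_per_rank:
            # First route reaching this rank: copy its features (inputs stay untouched).
            x_out_per_rank[dst_rank] = [list(row) for row in x_out]
        else:
            # Later routes: add element-wise (shapes must match).
            acc = x_out_per_rank[dst_rank]
            for i, row in enumerate(x_out):
                acc[i] = [a + b for a, b in zip(acc[i], row)]
    return x_out_per_rank


outputs = {
    (0, 0): [[1.0, 2.0], [3.0, 4.0]],  # node -> node route
    (1, 0): [[0.5, 0.5], [1.0, 1.0]],  # edge -> node route
    (0, 1): [[2.0, 0.0]],              # node -> edge route
}
print(aggregate_inter_nbhd_sketch(outputs))
# {0: [[1.5, 2.5], [4.0, 5.0]], 1: [[2.0, 0.0]]}
```

Routes that share a destination rank must produce the same number of cells; with tensors, the inner loop reduces to adding the tensors.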
- forward(batch)
Forward pass of the model.
- Parameters:
- batch : Complex or ComplexBatch(Complex)
The input data.
- Returns:
- dict
The output hidden states of the model per rank.
- generate_membership_vectors(batch)
Generate membership vectors based on batch.cell_statistics.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
The batch membership of the graphs per rank.
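As an illustration of what a membership vector is, the sketch below assumes (hypothetically) that `batch.cell_statistics` can be read as one row per graph giving the number of cells of each rank; the actual layout of that attribute is not documented here. Each output vector assigns every cell of a rank to the index of the graph it belongs to, like the `batch` vector PyTorch Geometric uses for nodes.

```python
from typing import Dict, List


def membership_from_cell_statistics(cell_statistics: List[List[int]]) -> Dict[int, List[int]]:
    """Map each rank to a vector giving the graph index of every cell of that rank.

    Assumes cell_statistics[g][r] is the number of rank-r cells in graph g.
    """
    n_ranks = len(cell_statistics[0])
    membership: Dict[int, List[int]] = {}
    for rank in range(n_ranks):
        vector: List[int] = []
        for graph_idx, counts in enumerate(cell_statistics):
            # Repeat the graph index once per cell of this rank in this graph.
            vector.extend([graph_idx] * counts[rank])
        membership[rank] = vector
    return membership


# Two graphs: the first has 3 nodes and 2 edges, the second 2 nodes and 1 edge.
stats = [[3, 2], [2, 1]]
print(membership_from_cell_statistics(stats))
# {0: [0, 0, 0, 1, 1], 1: [0, 0, 1]}
```

With tensors, the inner loop corresponds to `torch.repeat_interleave(torch.arange(num_graphs), counts)` applied to the per-graph counts of one rank.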
- get_nbhd_cache(params)
Cache the neighborhood information of the complex at hand in a dict.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- Returns:
- dict
The neighborhood cache.
- interrank_expand(params, src_rank, dst_rank, nbhd_cache, membership)
Expand the complex into an interrank Hasse graph.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- src_rank : int
The source rank.
- dst_rank : int
The destination rank.
- nbhd_cache : dict
The neighborhood cache containing the expanded boundary index and edge attributes.
- membership : dict
The batch membership of the graphs per rank.
- Returns:
- torch_geometric.data.Data
The expanded batch of interrank Hasse graphs for this route.
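The sketch below illustrates one plausible layout for an interrank Hasse graph, not necessarily the one TopoTune uses: destination cells keep ids `0..n_dst-1`, source cells are shifted by `n_dst`, and each incidence pair becomes a directed edge from a source cell to a destination cell. The function name `interrank_hasse_graph` and its plain-list inputs are hypothetical.

```python
from typing import List, Tuple


def interrank_hasse_graph(
    x_src: List[List[float]],
    x_dst: List[List[float]],
    incidence: List[Tuple[int, int]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Build a bipartite Hasse graph for one src_rank -> dst_rank route.

    Assumed layout: destination cells take node ids 0..n_dst-1 and source
    cells are shifted to n_dst..n_dst+n_src-1. incidence holds
    (src_cell, dst_cell) pairs.
    """
    n_dst = len(x_dst)
    x = [list(row) for row in x_dst] + [list(row) for row in x_src]
    # Messages flow from source cells (row 0) to destination cells (row 1).
    edge_index = [
        [s + n_dst for s, _ in incidence],
        [d for _, d in incidence],
    ]
    return x, edge_index


# Route nodes (rank 0) -> edges (rank 1) on the path 0-1-2, with e0=(0,1), e1=(1,2).
x_nodes = [[1.0], [2.0], [3.0]]
x_edges = [[0.0], [0.0]]
node_to_edge = [(0, 0), (1, 0), (1, 1), (2, 1)]
x, edge_index = interrank_hasse_graph(x_nodes, x_edges, node_to_edge)
print(edge_index)  # [[2, 3, 3, 4], [0, 0, 1, 1]]
```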
- interrank_gnn_forward(batch_route, layer_idx, route_index, n_dst_cells)
Forward pass of the GNN (one layer) for an interrank Hasse graph.
- Parameters:
- batch_route : torch_geometric.data.Data
The batch of interrank Hasse graphs for this route.
- layer_idx : int
The index of the TopoTune layer.
- route_index : int
The index of the route.
- n_dst_cells : int
The number of destination cells in the whole batch.
- Returns:
- torch.Tensor
The output of the GNN (updated features).
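Because the interrank Hasse graph contains both source and destination cells, a layer's output has to be cut back to the destination cells, which is what the `n_dst_cells` argument suggests. The sketch below assumes the destination cells come first and replaces the GNN with a hypothetical sum-aggregation layer (`sum_message_passing`) so that it runs without PyTorch.

```python
from typing import List


def sum_message_passing(x: List[List[float]], edge_index: List[List[int]]) -> List[List[float]]:
    """Stand-in for one GNN layer: each node sums its own and its incoming features."""
    out = [list(row) for row in x]  # self-loop term
    for src, dst in zip(edge_index[0], edge_index[1]):
        out[dst] = [a + b for a, b in zip(out[dst], x[src])]
    return out


def interrank_forward_sketch(
    x: List[List[float]], edge_index: List[List[int]], n_dst_cells: int
) -> List[List[float]]:
    """Run the layer on the whole Hasse graph, then keep only the destination cells."""
    out = sum_message_passing(x, edge_index)
    return out[:n_dst_cells]


# Assumed layout: destination cells are nodes 0-1, source cells are nodes 2-4.
x = [[0.0], [0.0], [1.0], [2.0], [3.0]]
edge_index = [[2, 3, 3, 4], [0, 0, 1, 1]]
print(interrank_forward_sketch(x, edge_index, n_dst_cells=2))  # [[3.0], [5.0]]
```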
- intrarank_expand(params, src_rank, nbhd)
Expand the complex into an intrarank Hasse graph.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- src_rank : int
The source rank.
- nbhd : str
The neighborhood to use.
- Returns:
- torch_geometric.data.Data
The expanded batch of intrarank Hasse graphs for this route.
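An intrarank Hasse graph connects cells of the same rank. As one example neighborhood (the neighborhood strings TopoTune accepts are not listed here), the sketch below builds the down-adjacency of edges, where two edges are linked when they share a node; `down_adjacency_edge_index` is a hypothetical helper returning a PyTorch Geometric-style `[sources, targets]` pair.

```python
from itertools import combinations
from typing import Dict, List, Set, Tuple


def down_adjacency_edge_index(edges: List[Tuple[int, int]]) -> List[List[int]]:
    """Intrarank Hasse graph on rank-1 cells: two edges are linked if they share a node.

    This is the off-diagonal support of B1^T B1, where B1 is the node-edge
    incidence matrix.
    """
    cells_at_node: Dict[int, Set[int]] = {}
    for e_idx, (u, v) in enumerate(edges):
        for node in (u, v):
            cells_at_node.setdefault(node, set()).add(e_idx)
    pairs: Set[Tuple[int, int]] = set()
    for incident in cells_at_node.values():
        for a, b in combinations(sorted(incident), 2):
            pairs.update({(a, b), (b, a)})  # undirected: store both directions
    ordered = sorted(pairs)
    return [[a for a, _ in ordered], [b for _, b in ordered]]


# Triangle 0-1-2: every pair of edges shares exactly one node.
print(down_adjacency_edge_index([(0, 1), (1, 2), (0, 2)]))
# [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]
```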
- intrarank_gnn_forward(batch_route, layer_idx, route_index)
Forward pass of the GNN (one layer) for an intrarank Hasse graph.
- Parameters:
- batch_route : torch_geometric.data.Data
The batch of intrarank Hasse graphs for this route.
- layer_idx : int
The index of the TopoTune layer.
- route_index : int
The index of the route.
- Returns:
- torch.Tensor
The output of the GNN (updated features).
- class TopoTune_OneHasse(GNN, neighborhoods, layers, use_edge_attr, activation)
Bases: Module
Tunes a GNN model using higher-order relations.
This class takes a GNN class and its keyword arguments as inputs and tunes it with the specified additional relations. Unlike TopoTune, this class expects a single Hasse graph as input, in which all higher-order neighborhoods are represented by a single adjacency matrix.
- Parameters:
- GNN : torch.nn.Module (a class, not an instance)
The GNN class to use, e.g. GAT or GCN.
- neighborhoods : list of lists
The neighborhoods of interest.
- layers : int
The number of layers to use. Each layer contains one GNN.
- use_edge_attr : bool
Whether to use edge attributes.
- activation : str
The activation function to use, e.g. 'relu', 'tanh', or 'sigmoid'.
- __init__(GNN, neighborhoods, layers, use_edge_attr, activation)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- aggregate_inter_nbhd(x_out)
Aggregate the outputs of the GNN for each rank.
The GNN handles intra-neighborhood aggregation; this method handles inter-neighborhood aggregation. Default: sum.
- Parameters:
- x_out : torch.Tensor
The output of the GNN: the concatenated features of all ranks.
- Returns:
- dict
The aggregated outputs of the GNN for each rank.
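Since `x_out` holds the features of all ranks stacked together, producing a per-rank dict amounts to slicing it by the number of cells in each rank. The sketch below assumes rows are ordered by rank (rank 0 first) and that the per-rank counts are known; `split_by_rank` is a hypothetical name.

```python
from typing import Dict, List


def split_by_rank(x_out: List[List[float]], n_cells_per_rank: List[int]) -> Dict[int, List[List[float]]]:
    """Slice a rank-concatenated feature matrix back into one block per rank.

    Assumes rows are ordered rank 0 first, then rank 1, and so on.
    """
    if sum(n_cells_per_rank) != len(x_out):
        raise ValueError("cell counts do not match the number of rows")
    per_rank: Dict[int, List[List[float]]] = {}
    start = 0
    for rank, n_cells in enumerate(n_cells_per_rank):
        per_rank[rank] = x_out[start:start + n_cells]
        start += n_cells
    return per_rank


x_out = [[1.0], [2.0], [3.0], [10.0], [20.0]]  # 3 nodes, then 2 edges
print(split_by_rank(x_out, [3, 2]))
# {0: [[1.0], [2.0], [3.0]], 1: [[10.0], [20.0]]}
```

For a tensor, `torch.split(x_out, n_cells_per_rank)` performs the same slicing and returns a tuple of views.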
- all_nbhds_expand(params, membership)
Expand the complex into a single Hasse graph that contains all ranks and all neighborhoods.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- membership : dict
The batch membership of the graphs per rank.
- Returns:
- torch_geometric.data.Data
The expanded Hasse graph.
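A single Hasse graph over all ranks needs a global node id for every cell. The sketch below, a hypothetical helper rather than the library code, offsets each rank's cells by the number of cells in lower ranks and then merges the edges of every neighborhood into one `edge_index`, matching the idea of representing all neighborhoods with a single adjacency matrix. The route keys are an assumption made for illustration.

```python
from typing import Dict, List, Tuple

# Hypothetical key: (src_rank, dst_rank); equal ranks denote an intrarank neighborhood.
Route = Tuple[int, int]


def single_hasse_edge_index(
    n_cells_per_rank: List[int],
    routes: Dict[Route, List[Tuple[int, int]]],
) -> List[List[int]]:
    """Merge every neighborhood into one edge_index over all cells of all ranks.

    Cells of rank r are shifted by the total number of cells in ranks < r,
    so each cell gets a unique global node id.
    """
    offsets = [0]
    for n_cells in n_cells_per_rank[:-1]:
        offsets.append(offsets[-1] + n_cells)
    sources: List[int] = []
    targets: List[int] = []
    for (src_rank, dst_rank), pairs in routes.items():
        for s, d in pairs:
            sources.append(s + offsets[src_rank])
            targets.append(d + offsets[dst_rank])
    return [sources, targets]


# Path graph 0-1-2: 3 nodes (rank 0) and 2 edges (rank 1).
routes = {
    (0, 0): [(0, 1), (1, 0), (1, 2), (2, 1)],  # node adjacency
    (0, 1): [(0, 0), (1, 0), (1, 1), (2, 1)],  # node -> edge incidence
}
print(single_hasse_edge_index([3, 2], routes))
# [[0, 1, 1, 2, 0, 1, 1, 2], [1, 0, 2, 1, 3, 3, 4, 4]]
```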
- all_nbhds_gnn_forward(batch_route, layer_idx)
Forward pass of the GNN (one layer) for the single Hasse graph that contains all ranks and all neighborhoods.
- Parameters:
- batch_route : torch_geometric.data.Data
The batch of Hasse graphs that contain all ranks and all neighborhoods.
- layer_idx : int
The index of the TopoTune layer.
- Returns:
- torch.Tensor
The output of the GNN (updated features).
- forward(batch)
Forward pass of the model.
- Parameters:
- batch : Complex or ComplexBatch(Complex)
The input data.
- Returns:
- dict
The output hidden states of the model per rank.
- generate_membership_vectors(batch)
Generate membership vectors based on batch.cell_statistics.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
The batch membership of the graphs per rank.
Submodules
- topobench.nn.backbones.combinatorial.gccn module
Data: Data.__init__(), Data.connected_components(), Data.debug(), Data.edge_subgraph(), Data.from_dict(), Data.get_all_edge_attrs(), Data.get_all_tensor_attrs(), Data.is_edge_attr(), Data.is_node_attr(), Data.stores_as(), Data.subgraph(), Data.to_dict(), Data.to_heterogeneous(), Data.to_namedtuple(), Data.update(), Data.validate(), Data.batch, Data.edge_attr, Data.edge_index, Data.edge_stores, Data.edge_weight, Data.face, Data.node_stores, Data.num_edge_features, Data.num_edge_types, Data.num_faces, Data.num_features, Data.num_node_features, Data.num_node_types, Data.num_nodes, Data.pos, Data.stores, Data.time, Data.x, Data.y
TopoTune, get_activation(), get_routes_from_neighborhoods(), interrank_boundary_index()
- topobench.nn.backbones.combinatorial.gccn_onehasse module
Data: Data.__init__(), Data.connected_components(), Data.debug(), Data.edge_subgraph(), Data.from_dict(), Data.get_all_edge_attrs(), Data.get_all_tensor_attrs(), Data.is_edge_attr(), Data.is_node_attr(), Data.stores_as(), Data.subgraph(), Data.to_dict(), Data.to_heterogeneous(), Data.to_namedtuple(), Data.update(), Data.validate(), Data.batch, Data.edge_attr, Data.edge_index, Data.edge_stores, Data.edge_weight, Data.face, Data.node_stores, Data.num_edge_features, Data.num_edge_types, Data.num_faces, Data.num_features, Data.num_node_features, Data.num_node_types, Data.num_nodes, Data.pos, Data.stores, Data.time, Data.x, Data.y
TopoTune_OneHasse, get_activation(), get_routes_from_neighborhoods()