topobench.nn.backbones.combinatorial.gccn module
Define the TopoTune class, which, given a choice of hyperparameters, instantiates a GCCN expecting a collection of strictly augmented Hasse graphs as input.
- class Data(x=None, edge_index=None, edge_attr=None, y=None, pos=None, time=None, **kwargs)
Bases: BaseData, FeatureStore, GraphStore
A data object describing a homogeneous graph. The data object can hold node-level, link-level and graph-level attributes. In general, Data tries to mimic the behavior of a regular Python dictionary. In addition, it provides useful functionality for analyzing graph structures, and provides basic PyTorch tensor functionalities. An accompanying tutorial is available in the PyG documentation.
```python
import torch
from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index, ...)

# Add additional arguments to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

# Analyzing the graph structure:
data.num_nodes
>>> 23

data.is_directed()
>>> False

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
```
- Parameters:
x (torch.Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)
edge_index (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)
edge_attr (torch.Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)
y (torch.Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)
pos (torch.Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)
time (torch.Tensor, optional) – The timestamps for each event with shape [num_edges] or [num_nodes]. (default: None)
**kwargs (optional) – Additional attributes.
- __init__(x=None, edge_index=None, edge_attr=None, y=None, pos=None, time=None, **kwargs)
- connected_components()
Extracts connected components of the graph using a union-find algorithm. The components are returned as a list of Data objects, where each object represents a connected component of the graph.
```python
import torch
from torch_geometric.data import Data

data = Data()
data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])
data.edge_index = torch.tensor(
    [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long
)

components = data.connected_components()
print(len(components))
>>> 2

# Print the whole component (not just its `x`) to see the Data summary:
print(components[0])
>>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])
```
- Returns:
A list of disconnected components.
- Return type:
List[Data]
- debug()
- edge_subgraph(subset)
Returns the induced subgraph given by the edge indices subset. Will currently preserve all the nodes in the graph, even if they are isolated after subgraph computation.
- Parameters:
subset (LongTensor or BoolTensor) – The edges to keep.
- classmethod from_dict(mapping)
Creates a Data object from a dictionary.
- get_all_edge_attrs()
Returns all registered edge attributes.
- get_all_tensor_attrs()
Obtains all feature attributes stored in Data.
- stores_as(data)
- subgraph(subset)
Returns the induced subgraph given by the node indices subset.
- Parameters:
subset (LongTensor or BoolTensor) – The nodes to keep.
- to_dict()
Returns a dictionary of stored key/value pairs.
- to_heterogeneous(node_type=None, edge_type=None, node_type_names=None, edge_type_names=None)
Converts a Data object to a heterogeneous HeteroData object. For this, node and edge attributes are split according to the node-level and edge-level vectors node_type and edge_type, respectively. node_type_names and edge_type_names can be used to give meaningful node and edge type names, respectively. That is, the node_type 0 is given by node_type_names[0]. If the Data object was constructed via to_homogeneous(), the object can be reconstructed without any need to pass in additional arguments.
- Parameters:
node_type (torch.Tensor, optional) – A node-level vector denoting the type of each node. (default: None)
edge_type (torch.Tensor, optional) – An edge-level vector denoting the type of each edge. (default: None)
node_type_names (List[str], optional) – The names of node types. (default: None)
edge_type_names (List[Tuple[str, str, str]], optional) – The names of edge types. (default: None)
- to_namedtuple()
Returns a NamedTuple of stored key/value pairs.
- update(data)
Updates the data object with the elements from another data object. Added elements will override existing ones (in case of duplicates).
- validate(raise_on_error=True)
Validates the correctness of the data.
- property num_features: int
Returns the number of features per node in the graph. Alias for num_node_features.
- property num_nodes: int | None
Returns the number of nodes in the graph.
Note
The number of nodes in the data object is automatically inferred in case node-level attributes are present, e.g., data.x. In some cases, however, a graph may only be given without any node-level attributes. PyG then guesses the number of nodes according to edge_index.max().item() + 1. However, if there are isolated nodes, this number may be incorrect, which can result in unexpected behavior. Thus, we recommend setting the number of nodes in your data object explicitly via data.num_nodes = .... You will be given a warning that requests you to do so.
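The pitfall described in the note can be illustrated without PyG. The following is a minimal pure-Python sketch in which plain lists stand in for tensors; the helper name `infer_num_nodes` is hypothetical and only mimics the `edge_index.max().item() + 1` heuristic.

```python
def infer_num_nodes(edge_index):
    """Mimic the max-index heuristic: highest node id seen in any edge, plus one."""
    return max(max(edge_index[0]), max(edge_index[1])) + 1


# A graph with 4 nodes, where nodes 2 and 3 are isolated (they appear in no edge).
true_num_nodes = 4
edge_index = [[0, 1], [1, 0]]  # COO format: one undirected edge between nodes 0 and 1

inferred = infer_num_nodes(edge_index)
print(inferred)  # 2, not 4: the isolated nodes are invisible to the heuristic
```

Setting the node count explicitly removes the ambiguity, because the count no longer depends on which node ids happen to occur in an edge.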
- class TopoTune(GNN, neighborhoods, layers, use_edge_attr, activation)
Bases: Module
Tunes a GNN model using higher-order relations.
This class takes a GNN and its kwargs as inputs, and tunes it with specified additional relations.
- Parameters:
- GNN : torch.nn.Module, a class not an object
The GNN class to use. ex: GAT, GCN.
- neighborhoods : list of lists
The neighborhoods of interest.
- layers : int
The number of layers to use. Each layer contains one GNN.
- use_edge_attr : bool
Whether to use edge attributes.
- activation : str
The activation function to use. ex: 'relu', 'tanh', 'sigmoid'.
- __init__(GNN, neighborhoods, layers, use_edge_attr, activation)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- aggregate_inter_nbhd(x_out_per_route)
Aggregate the outputs of the GNN for each rank.
While the GNN takes care of intra-neighborhood aggregation, this method takes care of inter-neighborhood aggregation. Default: sum.
- Parameters:
- x_out_per_route : dict
The outputs of the GNN for each route.
- Returns:
- dict
The aggregated outputs of the GNN for each rank.
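To make the inter-neighborhood step concrete, here is a minimal pure-Python sketch of summing per-route outputs that target the same rank. It is not the library's implementation: the helper name `aggregate_by_dst_rank`, the keying of `x_out_per_route` by route index, and the use of nested lists in place of tensors are all assumptions made for illustration.

```python
def aggregate_by_dst_rank(routes, x_out_per_route):
    """Sum the per-route outputs that target the same destination rank."""
    x_out_per_rank = {}
    for route_index, (_, dst_rank) in enumerate(routes):
        x_route = x_out_per_route[route_index]
        if dst_rank not in x_out_per_rank:
            # First route reaching this rank: copy so the input is not mutated.
            x_out_per_rank[dst_rank] = [list(row) for row in x_route]
        else:
            # Later routes: element-wise sum, cell by cell.
            acc = x_out_per_rank[dst_rank]
            for cell, row in enumerate(x_route):
                acc[cell] = [a + b for a, b in zip(acc[cell], row)]
    return x_out_per_rank


# Assumed layout: each route is [src_rank, dst_rank].
routes = [[0, 0], [1, 0], [1, 1]]
x_out_per_route = {
    0: [[1.0, 2.0], [3.0, 4.0]],      # route 0 -> 0: features for two rank-0 cells
    1: [[10.0, 20.0], [30.0, 40.0]],  # route 1 -> 0: same two rank-0 cells
    2: [[5.0, 6.0]],                  # route 1 -> 1: one rank-1 cell
}

x_out_per_rank = aggregate_by_dst_rank(routes, x_out_per_route)
print(x_out_per_rank)
```

In this toy input, rank 0 receives messages from two routes and their features are added, while rank 1 receives a single route and is passed through unchanged.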
- forward(batch)
Forward pass of the model.
- Parameters:
- batch : Complex or ComplexBatch(Complex)
The input data.
- Returns:
- dict
The output hidden states of the model per rank.
- generate_membership_vectors(batch)
Generate membership vectors based on batch.cell_statistics.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
The batch membership of the graphs per rank.
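A membership vector assigns each cell of a given rank to the graph it came from. The sketch below shows one way such vectors could be derived, assuming (this page does not state it) that `cell_statistics` is a per-graph table whose entry `[g][r]` is the number of rank-`r` cells in graph `g`. The function name and the list-based representation are hypothetical.

```python
def membership_from_cell_statistics(cell_statistics):
    """Build, for each rank, a vector mapping every cell to its graph index."""
    num_ranks = len(cell_statistics[0])
    membership = {}
    for rank in range(num_ranks):
        vector = []
        for graph_idx, counts in enumerate(cell_statistics):
            # Graph `graph_idx` contributes `counts[rank]` cells of this rank.
            vector.extend([graph_idx] * counts[rank])
        membership[rank] = vector
    return membership


# Assumed layout: two graphs; columns are counts of rank-0, rank-1, rank-2 cells.
cell_statistics = [
    [3, 2, 0],  # graph 0: 3 nodes, 2 edges, 0 faces
    [2, 1, 1],  # graph 1: 2 nodes, 1 edge, 1 face
]

membership = membership_from_cell_statistics(cell_statistics)
print(membership)
```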
- get_nbhd_cache(params)
Cache the neighborhood information into a dict for the complex at hand.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- Returns:
- dict
The neighborhood cache.
- interrank_expand(params, src_rank, dst_rank, nbhd_cache, membership)
Expand the complex into an interrank Hasse graph.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- src_rank : int
The source rank.
- dst_rank : int
The destination rank.
- nbhd_cache : dict
The neighborhood cache containing the expanded boundary index and edge attributes.
- membership : dict
The batch membership of the graphs per rank.
- Returns:
- torch_geometric.data.Data
The expanded batch of interrank Hasse graphs for this route.
- interrank_gnn_forward(batch_route, layer_idx, route_index, n_dst_cells)
Forward pass of the GNN (one layer) for an interrank Hasse graph.
- Parameters:
- batch_route : torch_geometric.data.Data
The batch of interrank Hasse graphs for this route.
- layer_idx : int
The index of the layer.
- route_index : int
The index of the route.
- n_dst_cells : int
The number of destination cells in the whole batch.
- Returns:
- torch.Tensor
The output of the GNN (updated features).
- intrarank_expand(params, src_rank, nbhd)
Expand the complex into an intrarank Hasse graph.
- Parameters:
- params : dict
The parameters of the batch, containing the complex.
- src_rank : int
The source rank.
- nbhd : str
The neighborhood to use.
- Returns:
- torch_geometric.data.Data
The expanded batch of intrarank Hasse graphs for this route.
- intrarank_gnn_forward(batch_route, layer_idx, route_index)
Forward pass of the GNN (one layer) for an intrarank Hasse graph.
- Parameters:
- batch_route : torch_geometric.data.Data
The batch of intrarank Hasse graphs for this route.
- layer_idx : int
The index of the TopoTune layer.
- route_index : int
The index of the route.
- Returns:
- torch.Tensor
The output of the GNN (updated features).
- get_activation(nonlinearity, return_module=False)
Activation resolver from CWN.
- Parameters:
- nonlinearity : str
The nonlinearity to use.
- return_module : bool
Whether to return the module or the function.
- Returns:
- module or function
The module or the function.
- get_routes_from_neighborhoods(neighborhoods)
Get the routes from the neighborhoods.
Each route is a combination [src_rank, dst_rank]. ex: [[0, 0], [1, 0], [1, 1], [1, 1], [2, 1]].
- Parameters:
- neighborhoods : list
List of neighborhoods of interest.
- Returns:
- list
List of routes.
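As a rough illustration of how neighborhoods could map to `[src_rank, dst_rank]` routes, consider the sketch below. The neighborhood string format used here (`"<direction>_<kind>-<rank>"`, for example `"down_incidence-1"`) is an assumption, since this page does not specify it, and `routes_from_neighborhoods` is a hypothetical stand-in rather than the library function.

```python
def routes_from_neighborhoods(neighborhoods):
    """Map neighborhood names to [src_rank, dst_rank] pairs (assumed naming scheme)."""
    routes = []
    for nbhd in neighborhoods:
        # Assumed format: "<direction>_<kind>-<src_rank>".
        name, rank = nbhd.rsplit("-", 1)
        src_rank = int(rank)
        if "incidence" in name:
            # Incidence relations connect neighboring ranks.
            dst_rank = src_rank + 1 if name.startswith("up") else src_rank - 1
        else:
            # Adjacency-type relations stay within the same rank.
            dst_rank = src_rank
        routes.append([src_rank, dst_rank])
    return routes


neighborhoods = [
    "up_adjacency-0",
    "down_incidence-1",
    "up_adjacency-1",
    "down_adjacency-1",
    "down_incidence-2",
]

routes = routes_from_neighborhoods(neighborhoods)
print(routes)
```

Under these assumptions the five neighborhoods yield the same list of routes as the example in the description above, including the repeated `[1, 1]` entry: two different adjacency relations on rank 1 produce two separate routes.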
- interrank_boundary_index(x_src, boundary_index, n_dst_nodes)
Recover lifted graph.
Edge-to-node boundary relationships of a graph with n_nodes and n_edges can be represented as up-adjacency node relations. There are n_nodes + n_edges nodes in this lifted graph. Designed to work for regular (edge-to-node and face-to-edge) boundary relationships.
- Parameters:
- x_src : torch.Tensor
Source node features. Shape [n_src_nodes, n_features]. Should represent edge or face features.
- boundary_index : list of lists or list of tensors
boundary_index[0] stores the ids of the nodes in the boundary of the edges stored in boundary_index[1]; boundary_index[1] stores the corresponding edge ids.
- n_dst_nodes : int
Number of destination nodes.
- Returns:
- edge_index : list of lists
edge_index[0][i] and edge_index[1][i] are the two nodes of edge i.
- edge_attr : tensor
Edge features, given by the features of the bounding node representing an edge. Shape [n_edges, n_features].
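The lifted-graph construction described above can be sketched in pure Python. This is an illustration under assumptions, not the library code: the helper name `lift_boundary_index` is hypothetical, lists replace tensors, and the choices to renumber source cells after the destination cells and to point lifted edges from source cell to destination cell are made here for concreteness.

```python
def lift_boundary_index(x_src, boundary_index, n_dst_nodes):
    """Turn (destination id, source id) boundary pairs into edges of a lifted graph."""
    dst_ids, src_ids = boundary_index
    # Source cells (e.g. edges) become extra nodes numbered after the destination nodes.
    lifted_src_ids = [s + n_dst_nodes for s in src_ids]
    edge_index = [lifted_src_ids, list(dst_ids)]
    # Each lifted edge carries the features of the source cell it starts from.
    edge_attr = [x_src[s] for s in src_ids]
    return edge_index, edge_attr


# Path graph 0 - 1 - 2 with edges e0 = (0, 1) and e1 = (1, 2).
boundary_index = [
    [0, 1, 1, 2],  # node ids on the boundary ...
    [0, 0, 1, 1],  # ... of these edge ids
]
x_src = [[0.5], [1.5]]  # one feature per edge
n_dst_nodes = 3

edge_index, edge_attr = lift_boundary_index(x_src, boundary_index, n_dst_nodes)
print(edge_index)
print(edge_attr)
```

With 3 nodes and 2 edges, the lifted graph has 3 + 2 = 5 nodes, so the largest node id appearing in `edge_index` is 4.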