topobench.nn.backbones.graph.gps module#
This module implements a GPS-based model [1] that can be used with the training framework.
GPS combines local message passing with global attention mechanisms. Uses the official PyTorch Geometric GPSConv implementation.
[1] Rampášek, Ladislav, et al. “Recipe for a general, powerful, scalable graph transformer.” Advances in Neural Information Processing Systems 35 (2022): 14501-14515.
- class Any(*args, **kwargs)#
Bases: object
Special type indicating an unconstrained type.
Any is compatible with every type.
Any is assumed to have all methods.
All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
- class GINConv(nn, eps=0.0, train_eps=False, **kwargs)#
Bases: MessagePassing
The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper.
\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]
or
\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]
where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.
- Parameters:
nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.
eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)
train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite
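A minimal construction sketch using the torch_geometric.nn.GINConv API documented above; the MLP architecture, feature sizes, and toy graph are illustrative assumptions, not values prescribed by this module:

```python
import torch
from torch_geometric.nn import GINConv

in_channels, out_channels = 64, 64

# The `nn` argument is the MLP h_Theta mapping [-1, in_channels] -> [-1, out_channels].
mlp = torch.nn.Sequential(
    torch.nn.Linear(in_channels, out_channels),
    torch.nn.ReLU(),
    torch.nn.Linear(out_channels, out_channels),
)
conv = GINConv(mlp, eps=0.0, train_eps=True)

x = torch.randn(10, in_channels)                    # 10 nodes, 64 features each
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])   # toy edge list of shape [2, |E|]
out = conv(x, edge_index)                           # node features of shape [10, out_channels]
```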
- __init__(nn, eps=0.0, train_eps=False, **kwargs)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, edge_index, size=None)#
Runs the forward pass of the module.
- message(x_j)#
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
- message_and_aggregate(adj_t, x)#
Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function will only get called in case it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.
- reset_parameters()#
Resets all learnable parameters of the module.
- class GPSConv(channels, conv, heads=1, dropout=0.0, act='relu', act_kwargs=None, norm='batch_norm', norm_kwargs=None, attn_type='multihead', attn_kwargs=None)#
Bases: Module
The general, powerful, scalable (GPS) graph transformer layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.
The GPS layer is based on a 3-part recipe:
1. Inclusion of positional (PE) and structural encodings (SE) to the input features (done in a pre-processing step via torch_geometric.transforms).
2. A local message passing layer (MPNN) that operates on the input graph.
3. A global attention layer that operates on the entire graph.
Note
For an example of using GPSConv, see examples/graph_gps.py.
- Parameters:
channels (int) – Size of each input sample.
conv (MessagePassing, optional) – The local message passing layer.
heads (int, optional) – Number of multi-head-attentions. (default: 1)
dropout (float, optional) – Dropout probability of intermediate embeddings. (default: 0.)
act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")
act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)
norm (str or Callable, optional) – The normalization function to use. (default: "batch_norm")
norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)
attn_type (str) – Global attention type, multihead or performer. (default: multihead)
attn_kwargs (Dict[str, Any], optional) – Arguments passed to the attention layer. (default: None)
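A hedged sketch of assembling one GPS layer from the signature above, pairing GPSConv with a GINConv local module; the channel width, MLP sizes, and toy graph are illustrative assumptions:

```python
import torch
from torch_geometric.nn import GINConv, GPSConv

channels = 64
local_mlp = torch.nn.Sequential(
    torch.nn.Linear(channels, channels),
    torch.nn.ReLU(),
    torch.nn.Linear(channels, channels),
)

# Local MPNN (GIN) combined with global multihead attention, per the 3-part recipe above.
layer = GPSConv(channels, GINConv(local_mlp), heads=4, dropout=0.1,
                attn_type="multihead")

x = torch.randn(10, channels)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
batch = torch.zeros(10, dtype=torch.long)   # all 10 nodes belong to one graph
out = layer(x, edge_index, batch=batch)     # shape: [10, channels]
```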
- __init__(channels, conv, heads=1, dropout=0.0, act='relu', act_kwargs=None, norm='batch_norm', norm_kwargs=None, attn_type='multihead', attn_kwargs=None)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, edge_index, batch=None, **kwargs)#
Runs the forward pass of the module.
- reset_parameters()#
Resets all learnable parameters of the module.
- class GPSEncoder(input_dim, hidden_dim, num_layers=4, heads=4, dropout=0.1, attn_type='multihead', local_conv_type='gin', use_edge_attr=False, redraw_interval=None, attn_kwargs=None)#
Bases: Module
GPS Encoder that can be used with the training framework.
Uses the official PyTorch Geometric GPSConv implementation. This encoder combines local message passing with global attention mechanisms for powerful graph representation learning.
- Parameters:
- input_dim : int
Dimension of input node features.
- hidden_dim : int
Dimension of hidden layers.
- num_layers : int, optional
Number of GPS layers. Default is 4.
- heads : int, optional
Number of attention heads in GPSConv layers. Default is 4.
- dropout : float, optional
Dropout rate for GPSConv layers. Default is 0.1.
- attn_type : str, optional
Type of attention mechanism to use. Options are ‘multihead’, ‘performer’, etc. Default is ‘multihead’.
- local_conv_type : str, optional
Type of local message passing layer. Options are ‘gin’, ‘pna’, etc. Default is ‘gin’.
- use_edge_attr : bool, optional
Whether to use edge attributes in GPSConv layers. Default is False.
- redraw_interval : int or None, optional
Interval for redrawing random projections in Performer attention. If None, projections are not redrawn. Default is None.
- attn_kwargs : dict, optional
Additional keyword arguments for the attention mechanism.
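A construction sketch based on the documented signature; the import path and the chosen dimensions are assumptions for illustration:

```python
from topobench.nn.backbones.graph.gps import GPSEncoder

# A stack of 4 GPS layers: GIN local message passing plus multihead global attention.
encoder = GPSEncoder(
    input_dim=32,
    hidden_dim=64,
    num_layers=4,
    heads=4,
    dropout=0.1,
    attn_type="multihead",
    local_conv_type="gin",
)
```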
- __init__(input_dim, hidden_dim, num_layers=4, heads=4, dropout=0.1, attn_type='multihead', local_conv_type='gin', use_edge_attr=False, redraw_interval=None, attn_kwargs=None)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, edge_index, batch=None, edge_attr=None, **kwargs)#
Forward pass of GPS encoder.
- Parameters:
- x : torch.Tensor
Node feature matrix of shape [num_nodes, input_dim].
- edge_index : torch.Tensor
Edge indices of shape [2, num_edges].
- batch : torch.Tensor, optional
Batch vector assigning each node to a specific graph. Shape [num_nodes]. Default is None.
- edge_attr : torch.Tensor, optional
Edge feature matrix of shape [num_edges, edge_dim]. Default is None.
- **kwargs : dict
Additional arguments (not used).
- Returns:
- torch.Tensor
Output node feature matrix of shape [num_nodes, hidden_dim].
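Continuing the construction sketch above, a forward call on a toy single-graph batch might look as follows (shapes and values are placeholders):

```python
import torch

x = torch.randn(10, 32)                     # [num_nodes, input_dim]
edge_index = torch.randint(0, 10, (2, 40))  # [2, num_edges]
batch = torch.zeros(10, dtype=torch.long)   # every node assigned to graph 0

out = encoder(x, edge_index, batch=batch)   # [num_nodes, hidden_dim] == [10, 64]
```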
- class PNAConv(in_channels, out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False, **kwargs)#
Bases: MessagePassing
The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper.
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right)\]
with
\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}\]
where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote MLPs.
Note
For an example of using PNAConv, see examples/pna.py.
- Parameters:
in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.
out_channels (int) – Size of each output sample.
aggregators (List[str]) – Set of aggregation function identifiers, namely "sum", "mean", "min", "max", "var" and "std".
scalers (List[str]) – Set of scaling function identifiers, namely "identity", "amplification", "attenuation", "linear" and "inverse_linear".
deg (torch.Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.
edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)
towers (int, optional) – Number of towers. (default: 1)
pre_layers (int, optional) – Number of transformation layers before aggregation. (default: 1)
post_layers (int, optional) – Number of transformation layers after aggregation. (default: 1)
divide_input (bool, optional) – Whether the input features should be split between towers or not. (default: False)
act (str or callable, optional) – Pre- and post-layer activation function to use. (default: "relu")
act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)
train_norm (bool, optional) – Whether normalization parameters are trainable. (default: False)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
- Shapes:
input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
output: node features \((|\mathcal{V}|, F_{out})\)
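A hedged sketch of building a PNAConv layer; the degree histogram is normally computed from the training set via PNAConv.get_degree_histogram(train_loader), but a toy histogram and toy sizes are substituted here for illustration:

```python
import torch
from torch_geometric.nn import PNAConv

# Toy in-degree histogram: index d counts how many training nodes have in-degree d.
deg = torch.tensor([0, 10, 20, 10, 5], dtype=torch.long)

conv = PNAConv(
    in_channels=64,
    out_channels=64,                 # must be divisible by `towers`
    aggregators=["mean", "min", "max", "std"],
    scalers=["identity", "amplification", "attenuation"],
    deg=deg,
    towers=4,
)

x = torch.randn(10, 64)
edge_index = torch.randint(0, 10, (2, 40))
out = conv(x, edge_index)            # shape: [10, 64]
```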
- __init__(in_channels, out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False, **kwargs)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, edge_index, edge_attr=None)#
Runs the forward pass of the module.
- static get_degree_histogram(loader)#
Returns the degree histogram to be used as input for the deg argument in PNAConv.
- message(x_i, x_j, edge_attr)#
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.
- reset_parameters()#
Resets all learnable parameters of the module.
- class PerformerAttention(channels, heads, head_channels=64, kernel=ReLU(), qkv_bias=False, attn_out_bias=True, dropout=0.0)#
Bases: Module
The linear scaled attention mechanism from the “Rethinking Attention with Performers” paper.
- Parameters:
channels (int) – Size of each input sample.
heads (int, optional) – Number of parallel attention heads.
head_channels (int, optional) – Size of each attention head. (default: 64)
kernel (Callable, optional) – Kernels for generalized attention. If not specified, ReLU kernel will be used. (default: torch.nn.ReLU())
qkv_bias (bool, optional) – If specified, add bias to query, key and value in the self attention. (default: False)
attn_out_bias (bool, optional) – If specified, add bias to the attention output. (default: True)
dropout (float, optional) – Dropout probability of the final attention output. (default: 0.0)
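A usage sketch on densely batched node features; the torch_geometric.nn.attention import path reflects recent PyTorch Geometric releases, and the sizes are illustrative assumptions:

```python
import torch
from torch_geometric.nn.attention import PerformerAttention

attn = PerformerAttention(channels=64, heads=4)

x = torch.randn(2, 10, 64)                  # B=2 graphs, N=10 node slots, F=64 features
mask = torch.ones(2, 10, dtype=torch.bool)  # all node slots are valid here
out = attn(x, mask=mask)                    # shape: [2, 10, 64]
```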
- __init__(channels, heads, head_channels=64, kernel=ReLU(), qkv_bias=False, attn_out_bias=True, dropout=0.0)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, mask=None)#
Forward pass.
- Parameters:
x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).
mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
- redraw_projection_matrix()#
As described in the paper, periodically redraw the random projection matrices to improve the overall approximation of attention.
- class RedrawProjection(model, redraw_interval=None)#
Bases: object
Helper class to handle redrawing of random projections in Performer attention.
This is crucial for maintaining the quality of the random feature approximation.
- Parameters:
- model : torch.nn.Module
The model containing PerformerAttention modules.
- redraw_interval : int or None, optional
Interval for redrawing random projections. If None, projections are not redrawn. Default is None.
- __init__(model, redraw_interval=None)#
- redraw_projections()#
Redraw random projections in PerformerAttention modules if needed.
- Returns:
- None
None.
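A training-loop sketch of how the helper might be driven; the import path, the placeholder model, and the interval are assumptions for illustration:

```python
import torch
from topobench.nn.backbones.graph.gps import RedrawProjection

# `model` stands for any module containing PerformerAttention layers,
# e.g. a GPS encoder built with attn_type="performer".
model = torch.nn.Module()  # placeholder for such a model
redraw = RedrawProjection(model, redraw_interval=1000)

for step in range(10_000):         # toy training loop
    redraw.redraw_projections()    # redraws projections once per `redraw_interval` calls
    # ... forward pass, loss, backward, optimizer step ...
```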