topobench.nn.backbones.graph.gps module#

This module implements a GPS-based model [1] that can be used with the training framework.

GPS combines local message passing with global attention mechanisms. Uses the official PyTorch Geometric GPSConv implementation.

[1] Rampášek, Ladislav, et al. “Recipe for a general, powerful, scalable graph transformer.” Advances in Neural Information Processing Systems 35 (2022): 14501-14515.

class Any(*args, **kwargs)#

Bases: object

Special type indicating an unconstrained type.

  • Any is compatible with every type.

  • Any is assumed to have all methods.

  • All values are assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.

class GINConv(nn, eps=0.0, train_eps=False, **kwargs)#

Bases: MessagePassing

The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper.

\[\mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)\]

or

\[\mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),\]

where \(h_{\mathbf{\Theta}}\) denotes a neural network, i.e., an MLP.

Parameters:
  • nn (torch.nn.Module) – A neural network \(h_{\mathbf{\Theta}}\) that maps node features x of shape [-1, in_channels] to shape [-1, out_channels], e.g., defined by torch.nn.Sequential.

  • eps (float, optional) – (Initial) \(\epsilon\)-value. (default: 0.)

  • train_eps (bool, optional) – If set to True, \(\epsilon\) will be a trainable parameter. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite
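As a quick illustration of the signature above, the following minimal sketch (sizes and values are purely illustrative, assuming a standard PyTorch Geometric installation) wraps a small torch.nn.Sequential MLP as \(h_{\mathbf{\Theta}}\):

import torch
from torch_geometric.nn import GINConv

# The wrapped MLP acts as h_Theta, mapping 16-dim inputs to 32-dim outputs.
mlp = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
)
conv = GINConv(mlp, eps=0.0, train_eps=True)

x = torch.randn(4, 16)                     # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                           [1, 0, 3, 2]])  # target nodes
out = conv(x, edge_index)                  # node features of shape [4, 32]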

__init__(nn, eps=0.0, train_eps=False, **kwargs)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, edge_index, size=None)#

Runs the forward pass of the module.

message(x_j)#

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

message_and_aggregate(adj_t, x)#

Fuses computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not explicitly need to be materialized. This function only gets called if it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.

reset_parameters()#

Resets all learnable parameters of the module.

class GPSConv(channels, conv, heads=1, dropout=0.0, act='relu', act_kwargs=None, norm='batch_norm', norm_kwargs=None, attn_type='multihead', attn_kwargs=None)#

Bases: Module

The general, powerful, scalable (GPS) graph transformer layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.

The GPS layer is based on a 3-part recipe:

  1. Inclusion of positional (PE) and structural encodings (SE) to the input features (done in a pre-processing step via torch_geometric.transforms).

  2. A local message passing layer (MPNN) that operates on the input graph.

  3. A global attention layer that operates on the entire graph.

Note

For an example of using GPSConv, see examples/graph_gps.py.

Parameters:
  • channels (int) – Size of each input sample.

  • conv (MessagePassing, optional) – The local message passing layer.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • dropout (float, optional) – Dropout probability of intermediate embeddings. (default: 0.)

  • act (str or Callable, optional) – The non-linear activation function to use. (default: "relu")

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • norm (str or Callable, optional) – The normalization function to use. (default: "batch_norm")

  • norm_kwargs (Dict[str, Any], optional) – Arguments passed to the respective normalization function defined by norm. (default: None)

  • attn_type (str) – Global attention type, multihead or performer. (default: multihead)

  • attn_kwargs (Dict[str, Any], optional) – Arguments passed to the attention layer. (default: None)
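A hedged construction sketch, combining GPSConv with a GINConv as the local message passing layer (all sizes and hyperparameters are illustrative, assuming a standard PyTorch Geometric installation):

import torch
from torch_geometric.nn import GINConv, GPSConv

channels = 64

# Local MPNN: a GIN layer whose MLP keeps the channel width fixed.
local_mlp = torch.nn.Sequential(
    torch.nn.Linear(channels, channels),
    torch.nn.ReLU(),
    torch.nn.Linear(channels, channels),
)
layer = GPSConv(channels, conv=GINConv(local_mlp), heads=4,
                dropout=0.1, attn_type='multihead')

x = torch.randn(10, channels)               # 10 nodes, already embedded to `channels`
edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
batch = torch.zeros(10, dtype=torch.long)   # all nodes belong to one graph
out = layer(x, edge_index, batch=batch)     # node features of shape [10, channels]

Positional and structural encodings are assumed to have been added to x in a pre-processing step (part 1 of the recipe).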

__init__(channels, conv, heads=1, dropout=0.0, act='relu', act_kwargs=None, norm='batch_norm', norm_kwargs=None, attn_type='multihead', attn_kwargs=None)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, edge_index, batch=None, **kwargs)#

Runs the forward pass of the module.

reset_parameters()#

Resets all learnable parameters of the module.

class GPSEncoder(input_dim, hidden_dim, num_layers=4, heads=4, dropout=0.1, attn_type='multihead', local_conv_type='gin', use_edge_attr=False, redraw_interval=None, attn_kwargs=None)#

Bases: Module

GPS Encoder that can be used with the training framework.

Uses the official PyTorch Geometric GPSConv implementation. This encoder combines local message passing with global attention mechanisms for powerful graph representation learning.

Parameters:
input_dim : int

Dimension of input node features.

hidden_dim : int

Dimension of hidden layers.

num_layers : int, optional

Number of GPS layers. Default is 4.

heads : int, optional

Number of attention heads in GPSConv layers. Default is 4.

dropout : float, optional

Dropout rate for GPSConv layers. Default is 0.1.

attn_type : str, optional

Type of attention mechanism to use. Options are ‘multihead’, ‘performer’, etc. Default is ‘multihead’.

local_conv_type : str, optional

Type of local message passing layer. Options are ‘gin’, ‘pna’, etc. Default is ‘gin’.

use_edge_attr : bool, optional

Whether to use edge attributes in GPSConv layers. Default is False.

redraw_interval : int or None, optional

Interval for redrawing random projections in Performer attention. If None, projections are not redrawn. Default is None.

attn_kwargs : dict, optional

Additional keyword arguments for the attention mechanism.

__init__(input_dim, hidden_dim, num_layers=4, heads=4, dropout=0.1, attn_type='multihead', local_conv_type='gin', use_edge_attr=False, redraw_interval=None, attn_kwargs=None)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, edge_index, batch=None, edge_attr=None, **kwargs)#

Forward pass of GPS encoder.

Parameters:
x : torch.Tensor

Node feature matrix of shape [num_nodes, input_dim].

edge_index : torch.Tensor

Edge indices of shape [2, num_edges].

batch : torch.Tensor, optional

Batch vector assigning each node to a specific graph. Shape [num_nodes]. Default is None.

edge_attr : torch.Tensor, optional

Edge feature matrix of shape [num_edges, edge_dim]. Default is None.

**kwargs : dict

Additional arguments (not used).

Returns:
torch.Tensor

Output node feature matrix of shape [num_nodes, hidden_dim].
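A usage sketch based on the signatures documented above; the import path follows this module's location, and all argument values and tensor sizes are illustrative:

import torch
from topobench.nn.backbones.graph.gps import GPSEncoder

encoder = GPSEncoder(input_dim=16, hidden_dim=64, num_layers=4, heads=4,
                     dropout=0.1, attn_type='multihead', local_conv_type='gin')

x = torch.randn(8, 16)                     # 8 nodes with 16-dim input features
edge_index = torch.randint(0, 8, (2, 24))  # 24 random edges
batch = torch.zeros(8, dtype=torch.long)   # a single graph in the batch
out = encoder(x, edge_index, batch=batch)  # node features of shape [8, 64]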

class PNAConv(in_channels, out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False, **kwargs)#

Bases: MessagePassing

The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper.

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \underset{j \in \mathcal{N}(i)}{\bigoplus} h_{\mathbf{\Theta}} \left( \mathbf{x}_i, \mathbf{x}_j \right) \right)\]

with

\[\begin{split}\bigoplus = \underbrace{\begin{bmatrix} 1 \\ S(\mathbf{D}, \alpha=1) \\ S(\mathbf{D}, \alpha=-1) \end{bmatrix} }_{\text{scalers}} \otimes \underbrace{\begin{bmatrix} \mu \\ \sigma \\ \max \\ \min \end{bmatrix}}_{\text{aggregators}},\end{split}\]

where \(\gamma_{\mathbf{\Theta}}\) and \(h_{\mathbf{\Theta}}\) denote MLPs.

Note

For an example of using PNAConv, see examples/pna.py.

Parameters:
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • aggregators (List[str]) – Set of aggregation function identifiers, namely "sum", "mean", "min", "max", "var" and "std".

  • scalers (List[str]) – Set of scaling function identifiers, namely "identity", "amplification", "attenuation", "linear" and "inverse_linear".

  • deg (torch.Tensor) – Histogram of in-degrees of nodes in the training set, used by scalers to normalize.

  • edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default None)

  • towers (int, optional) – Number of towers (default: 1).

  • pre_layers (int, optional) – Number of transformation layers before aggregation (default: 1).

  • post_layers (int, optional) – Number of transformation layers after aggregation (default: 1).

  • divide_input (bool, optional) – Whether the input features should be split between towers or not (default: False).

  • act (str or callable, optional) – Pre- and post-layer activation function to use. (default: "relu")

  • act_kwargs (Dict[str, Any], optional) – Arguments passed to the respective activation function defined by act. (default: None)

  • train_norm (bool, optional) – Whether normalization parameters are trainable. (default: False)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)

  • output: node features \((|\mathcal{V}|, F_{out})\)
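A minimal construction sketch (sizes and the hand-built degree histogram are illustrative; in practice deg is typically obtained via get_degree_histogram() below):

import torch
from torch_geometric.nn import PNAConv

# deg[d] counts how many training-set nodes have in-degree d.
# A hand-built histogram keeps the sketch self-contained; normally use
# PNAConv.get_degree_histogram(train_loader) on the training DataLoader.
deg = torch.tensor([0, 10, 20, 10, 5])

conv = PNAConv(in_channels=16, out_channels=32,
               aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'],
               deg=deg)

x = torch.randn(6, 16)
edge_index = torch.randint(0, 6, (2, 18))
out = conv(x, edge_index)                  # node features of shape [6, 32]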

__init__(in_channels, out_channels, aggregators, scalers, deg, edge_dim=None, towers=1, pre_layers=1, post_layers=1, divide_input=False, act='relu', act_kwargs=None, train_norm=False, **kwargs)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, edge_index, edge_attr=None)#

Runs the forward pass of the module.

static get_degree_histogram(loader)#

Returns the degree histogram to be used as input for the deg argument in PNAConv.

message(x_i, x_j, edge_attr)#

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

reset_parameters()#

Resets all learnable parameters of the module.

class PerformerAttention(channels, heads, head_channels=64, kernel=ReLU(), qkv_bias=False, attn_out_bias=True, dropout=0.0)#

Bases: Module

The linear scaled attention mechanism from the “Rethinking Attention with Performers” paper.

Parameters:
  • channels (int) – Size of each input sample.

  • heads (int, optional) – Number of parallel attention heads.

  • head_channels (int, optional) – Size of each attention head. (default: 64.)

  • kernel (Callable, optional) – Kernels for generalized attention. If not specified, ReLU kernel will be used. (default: torch.nn.ReLU())

  • qkv_bias (bool, optional) – If specified, add bias to query, key and value in the self attention. (default: False)

  • attn_out_bias (bool, optional) – If specified, add bias to the attention output. (default: True)

  • dropout (float, optional) – Dropout probability of the final attention output. (default: 0.0)

__init__(channels, heads, head_channels=64, kernel=ReLU(), qkv_bias=False, attn_out_bias=True, dropout=0.0)#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, mask=None)#

Forward pass.

Parameters:
  • x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)
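A small sketch of calling the layer on a dense node-feature batch (batch size, node count, and feature dimension are illustrative; the import assumes the PerformerAttention implementation shipped with recent PyTorch Geometric releases):

import torch
from torch_geometric.nn.attention import PerformerAttention

attn = PerformerAttention(channels=64, heads=4)

x = torch.randn(2, 10, 64)                  # B=2 graphs, N=10 node slots, F=64 features
mask = torch.ones(2, 10, dtype=torch.bool)  # all node slots are valid here
out = attn(x, mask=mask)                    # output of shape [2, 10, 64]

# Periodically refresh the random feature projections during training.
attn.redraw_projection_matrix()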

redraw_projection_matrix()#

As described in the paper, periodically redraw the random projection matrices to improve the overall approximation of attention.

class RedrawProjection(model, redraw_interval=None)#

Bases: object

Helper class to handle redrawing of random projections in Performer attention.

This is crucial for maintaining the quality of the random feature approximation.

Parameters:
model : torch.nn.Module

The model containing PerformerAttention modules.

redraw_interval : int or None, optional

Interval for redrawing random projections. If None, projections are not redrawn. Default is None.

__init__(model, redraw_interval=None)#
redraw_projections()#

Redraw random projections in PerformerAttention modules if needed.

Returns:
None

None.
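A hedged training-loop sketch showing how the helper's own API can be driven. Note that GPSEncoder already accepts redraw_interval directly, so the standalone wrapper below is purely illustrative, as are the placeholder data, loss, and hyperparameters:

import torch
from topobench.nn.backbones.graph.gps import GPSEncoder, RedrawProjection

model = GPSEncoder(input_dim=16, hidden_dim=64, attn_type='performer')
redraw = RedrawProjection(model, redraw_interval=100)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for step in range(1000):
    x = torch.randn(8, 16)                    # random placeholder node features
    edge_index = torch.randint(0, 8, (2, 24))
    out = model(x, edge_index)
    loss = out.pow(2).mean()                  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    redraw.redraw_projections()               # redraws once per redraw_interval steps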