topobench.nn.readouts package
This module contains the readout classes used by the library. The classes are exported automatically.
- class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Bases: Module
Readout layer for GNNs that operates on the batch level.
- Parameters:
- hidden_dim : int
Hidden dimension of the GNN model.
- out_channels : int
Number of output channels.
- task_level : str
Task level for readout layer. Either “graph” or “node”.
- pooling_type : str
Pooling type for readout layer. Either “max”, “sum” or “mean”.
- logits_linear_layer : bool
Whether to use a linear layer for getting the final logits.
- **kwargs : dict
Additional arguments.
- __init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute_logits(x, batch)
Compute logits based on the readout layer.
- Parameters:
- x : torch.Tensor
Node embeddings.
- batch : torch.Tensor
Batch index tensor.
- Returns:
- torch.Tensor
Logits tensor.
- abstract forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
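The pooling and optional linear projection that `compute_logits` is described as performing can be sketched in plain Python. This is a minimal sketch and not the library's implementation. It makes three assumptions: rows are pooled by the batch index only when `task_level == "graph"`, node-level tasks skip pooling, and the logits layer is a plain affine map. The helpers `pool_by_batch`, `linear` and `compute_logits_sketch` are hypothetical names. Lists stand in for `torch.Tensor` so that the example runs without PyTorch.

```python
from typing import List, Optional

Matrix = List[List[float]]


def pool_by_batch(x: Matrix, batch: List[int], pooling_type: str = "sum") -> Matrix:
    """Pool the rows of `x` into one row per graph, grouped by the `batch` index."""
    if pooling_type not in ("sum", "max", "mean"):
        raise ValueError(f"unsupported pooling_type: {pooling_type!r}")
    num_graphs = max(batch) + 1
    groups: List[Matrix] = [[] for _ in range(num_graphs)]
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)
    pooled: Matrix = []
    for rows in groups:
        columns = list(zip(*rows))  # one tuple of values per feature
        if pooling_type == "sum":
            pooled.append([sum(c) for c in columns])
        elif pooling_type == "max":
            pooled.append([max(c) for c in columns])
        else:  # "mean"
            pooled.append([sum(c) / len(rows) for c in columns])
    return pooled


def linear(x: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Affine map y = x W^T + b, with `weight` shaped (out_channels, hidden_dim)."""
    return [
        [sum(w * a for w, a in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


def compute_logits_sketch(
    x: Matrix,
    batch: List[int],
    task_level: str,
    pooling_type: str = "sum",
    weight: Optional[Matrix] = None,
    bias: Optional[List[float]] = None,
) -> Matrix:
    """Hypothetical readout: pool per graph for graph-level tasks, then project."""
    if task_level == "graph":
        x = pool_by_batch(x, batch, pooling_type)
    elif task_level != "node":
        raise ValueError(f"task_level must be 'graph' or 'node', got {task_level!r}")
    if weight is not None:  # plays the role of logits_linear_layer=True
        x = linear(x, weight, bias or [0.0] * len(weight))
    return x


# Three nodes: the first two belong to graph 0, the third to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
print(pool_by_batch(x, batch, "mean"))  # [[2.0, 3.0], [5.0, 6.0]]
print(compute_logits_sketch(x, batch, "graph", "sum", [[1.0, -1.0]], [0.5]))  # [[-1.5], [-0.5]]
```

In PyTorch Geometric, the grouping step is usually done with vectorized pooling ops over the `batch` vector rather than with Python loops. The loop version above only makes the per-graph reduction explicit.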
- class MLPReadout(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)
Bases: MLP
MLP-based readout over 0-cells (i.e., nodes).
This class implements a readout layer for graph neural networks, allowing for customizable MLP layers and pooling strategies.
- Parameters:
- in_channels : int
The dimensionality of the input features.
- hidden_layers : int
The dimensionality of the hidden MLP layers.
- out_channels : int
The dimensionality of the output features.
- pooling_type : str, optional
Pooling type for readout layer. Either “max”, “sum” or “mean” (default “sum”).
- dropout : float, optional
The dropout rate (default 0.25).
- norm : str, optional
The normalization layer to use (default None).
- norm_kwargs : dict, optional
Additional keyword arguments for the normalization layer (default None).
- act : str, optional
The activation function to use (default “relu”).
- act_kwargs : dict, optional
Additional keyword arguments for the activation function (default None).
- final_act : str, optional
The final activation function to use (default None).
- final_act_kwargs : dict, optional
Additional keyword arguments for the final activation function (default None).
- task_level : str, optional
Task level for readout layer. Either “graph” or “node” (default None).
- **kwargs : dict
Additional keyword arguments.
- __init__(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(model_out, batch)
Apply the readout to the model output.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
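The MLP stack that `MLPReadout` is built around can be pictured with a framework-free sketch. The sketch rests on these assumptions:

- Layers are affine maps with the `act` activation (ReLU here) and inverted dropout between hidden layers.
- `final_act` is applied only when one is given. With the signature default `final_act=None`, the output therefore stays as raw logits.

The helper names and the explicit list of `(weight, bias)` pairs are hypothetical. How the real class builds its layers from `hidden_layers`, and where it applies pooling relative to the MLP, is not specified here.

```python
import math
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
Layer = Tuple[List[List[float]], List[float]]  # (weight[out][in], bias[out])


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


def sigmoid(v: Vector) -> Vector:
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def dense(v: Vector, layer: Layer) -> Vector:
    weight, bias = layer
    return [sum(w * a for w, a in zip(row, v)) + b for row, b in zip(weight, bias)]


def mlp_forward(
    v: Vector,
    layers: Sequence[Layer],
    dropout: float = 0.25,
    training: bool = False,
    act: Callable[[Vector], Vector] = relu,
    final_act: Optional[Callable[[Vector], Vector]] = None,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Affine -> activation -> dropout for hidden layers; plain affine at the end."""
    rng = rng or random.Random(0)
    keep = 1.0 - dropout
    for i, layer in enumerate(layers):
        v = dense(v, layer)
        if i < len(layers) - 1:  # hidden layer
            v = act(v)
            if training and dropout > 0.0:
                # Inverted dropout: zero a unit with probability `dropout`, rescale the rest.
                v = [a / keep if rng.random() < keep else 0.0 for a in v]
    return final_act(v) if final_act is not None else v


# in_channels=2 -> hidden width 3 -> out_channels=1
layers = [
    ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, -5.0]),
    ([[1.0, 1.0, 1.0]], [0.0]),
]
print(mlp_forward([1.0, 2.0], layers))  # [3.0]
print(mlp_forward([1.0, 2.0], layers, final_act=sigmoid))  # about [0.9526]
```

Dropout is only active when `training=True`, which matches the usual train/eval split of `torch.nn.Module`.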
- class NoReadOut(**kwargs)
Bases: AbstractZeroCellReadOut
No readout layer.
This readout layer does not perform any operation on the node embeddings.
- Parameters:
- **kwargs : dict, optional
Additional keyword arguments.
- __init__(**kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(model_out, batch)
Forward pass of the no-readout layer.
It returns the model output without any modification.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output.
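Two behaviours described above can be mirrored in plain Python: the contract shared by these readouts, an abstract `forward(model_out, batch)` that returns a dictionary, and the identity behaviour of `NoReadOut`. `ReadoutSketch` and `NoReadOutSketch` are hypothetical stand-ins built on `abc`. They are not the library classes, which subclass `torch.nn.Module`. The `x_0` key is likewise only an illustrative assumption.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ReadoutSketch(ABC):
    """Hypothetical stand-in for an abstract readout: subclasses must define forward."""

    @abstractmethod
    def forward(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        """Return the (possibly updated) model-output dictionary."""

    def __call__(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        # Like nn.Module, calling the object dispatches to forward().
        return self.forward(model_out, batch)


class NoReadOutSketch(ReadoutSketch):
    """Identity readout: the model output is passed through untouched."""

    def forward(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        return model_out


model_out = {"x_0": [[0.1, 0.2], [0.3, 0.4]]}  # hypothetical key for node embeddings
result = NoReadOutSketch()(model_out, batch=None)
print(result is model_out)  # True
```

Instantiating `ReadoutSketch` directly raises `TypeError`, because `forward` is abstract. This is the same guarantee that the `abstract forward` entry expresses for the real base class.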
- class PropagateSignalDown(**kwargs)
Bases: AbstractZeroCellReadOut
Propagate-signal-down readout layer.
This readout layer propagates the signal from cells of a given order down to cells of lower order.
- Parameters:
- **kwargs : dict
Additional keyword arguments. They should contain the following keys:
- num_cell_dimensions (int): Highest order of cells considered by the model.
- hidden_dim (int): Dimension of the cell representations.
- readout_name (str): Readout name.
- __init__(**kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(model_out, batch)
Forward pass of the propagate-signal-down readout layer.
The layer takes the embeddings of the cells of a given order and applies a convolutional layer to them, followed by layer normalization. The result is concatenated with the initial embeddings of the cells and projected by a linear layer to the dimension of the lower-rank cells. The process is repeated until the node embeddings (the cells of rank 0) are reached.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
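The top-down loop described for `PropagateSignalDown.forward` can be sketched in plain Python. The sketch rests on several assumptions that the text does not pin down:

- The "convolutional layer" is replaced by a bare incidence-matrix aggregation with no learned weights.
- Layer normalization has no learnable scale or shift.
- The convolved features are concatenated with the current embeddings of the next-lower rank, which is what makes the row counts match.
- Features are stored under hypothetical keys such as `x_0` and `x_1`.

The function and argument names are illustrative and are not the library's API.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def layer_norm(x: Matrix, eps: float = 1e-5) -> Matrix:
    """Normalize each row to zero mean and unit variance (no learnable scale/shift)."""
    out = []
    for row in x:
        mu = sum(row) / len(row)
        var = sum((a - mu) ** 2 for a in row) / len(row)
        out.append([(a - mu) / math.sqrt(var + eps) for a in row])
    return out


def propagate_signal_down(
    model_out: Dict[str, Matrix],
    incidences: Dict[int, Matrix],
    projections: Dict[int, Matrix],
    num_cell_dimensions: int,
) -> Dict[str, Matrix]:
    """Move features from the highest cell rank down to rank 0 (nodes)."""
    out = dict(model_out)  # leave the caller's dictionary untouched
    for r in range(num_cell_dimensions, 0, -1):
        # Stand-in "convolution": aggregate rank-r features onto rank-(r-1) cells
        # with the incidence matrix B_r, shaped (n_{r-1}, n_r), then normalize.
        h = layer_norm(matmul(incidences[r], out[f"x_{r}"]))
        # Concatenate with the current rank-(r-1) embeddings (assumed), then
        # project from 2 * hidden_dim back to hidden_dim.
        h = [row_h + row_prev for row_h, row_prev in zip(h, out[f"x_{r-1}"])]
        out[f"x_{r-1}"] = matmul(h, projections[r])
    return out


# One edge (rank 1) joining two nodes (rank 0); hidden_dim = 2.
model_out = {"x_0": [[1.0, 0.0], [0.0, 1.0]], "x_1": [[2.0, 4.0]]}
incidences = {1: [[1.0], [1.0]]}
projections = {1: [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]}  # shape (4, 2)
result = propagate_signal_down(model_out, incidences, projections, num_cell_dimensions=1)
print(result["x_0"])  # approximately [[0.0, 1.0], [-1.0, 2.0]]
```

Iterating from `num_cell_dimensions` down to 1 means each rank is updated with information that has already been pushed down from every higher rank. Only rank-0 embeddings are needed by the node-based readouts.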