topobench.nn.readouts.mlp_readout module#
MLP-based Readout over 0-cells (i.e. nodes).
- class MLPReadout(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)#
Bases: MLP
MLP-based Readout over 0-cells (i.e. nodes).
This class implements a readout layer for graph neural networks, allowing for customizable MLP layers and pooling strategies.
- Parameters:
- in_channels : int
The dimensionality of the input features.
- hidden_layers : int
The dimensionality of the hidden MLP layers.
- out_channels : int
The dimensionality of the output features.
- pooling_type : str
Pooling type for readout layer. Either “max”, “sum” or “mean”.
- dropout : float, optional
The dropout rate (default 0.25).
- norm : str, optional
The normalization layer to use (default None).
- norm_kwargs : dict, optional
Additional keyword arguments for the normalization layer (default None).
- act : str, optional
The activation function to use (default “relu”).
- act_kwargs : dict, optional
Additional keyword arguments for the activation function (default None).
- final_act : str, optional
The final activation function to use (default None).
- final_act_kwargs : dict, optional
Additional keyword arguments for the final activation function (default None).
- task_level : str
Task level for readout layer. Either “graph” or “node”.
- **kwargs
Additional keyword arguments.
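A minimal usage sketch is shown below. The import path follows the module name above; the feature sizes and keyword values are illustrative assumptions, not prescribed by this API.

```python
# Illustrative instantiation only; the concrete sizes are assumptions.
from topobench.nn.readouts.mlp_readout import MLPReadout

readout = MLPReadout(
    in_channels=64,       # dimensionality of the input features
    hidden_layers=128,    # dimensionality of the hidden MLP layers
    out_channels=7,       # dimensionality of the output features
    pooling_type="mean",  # one of "max", "sum", "mean"
    dropout=0.25,
    act="relu",
    task_level="graph",   # "graph" or "node"
)
```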
- __init__(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(model_out, batch)#
Readout logic based on the model output.
- Parameters:
- model_outdict
Dictionary containing the model output.
- batchtorch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
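For orientation, the sketch below shows what a graph-level readout over 0-cells typically does: pool per-node features into one vector per graph, then project with an MLP. The key names ("x_0", "batch_0", "logits") and the stand-in linear layer are assumptions for illustration and may differ from the actual implementation.

```python
# Conceptual sketch only; key names ("x_0", "batch_0", "logits") are assumed,
# not taken from this documentation.
import torch
from torch_geometric.data import Data
from torch_geometric.utils import scatter

mlp = torch.nn.Linear(64, 7)  # stand-in for the configured MLP

def readout_sketch(model_out: dict, batch: Data) -> dict:
    x = model_out["x_0"]                                      # per-node (0-cell) features
    pooled = scatter(x, batch.batch_0, dim=0, reduce="mean")  # one vector per graph
    model_out["logits"] = mlp(pooled)                         # project to out_channels
    return model_out

batch = Data(
    x_0=torch.randn(10, 64),                    # 10 nodes with 64 features
    batch_0=torch.zeros(10, dtype=torch.long),  # all nodes belong to graph 0
)
out = readout_sketch({"x_0": batch.x_0}, batch)  # out["logits"].shape == (1, 7)
```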
- scatter(src, index, dim=0, dim_size=None, reduce='sum')#
Reduces all values from the src tensor at the indices specified in the index tensor along a given dimension dim. See the documentation of the torch_scatter package for more information.
- Parameters:
- src (torch.Tensor) – The source tensor.
- index (torch.Tensor) – The index tensor.
- dim (int, optional) – The dimension along which to index. (default: 0)
- dim_size (int, optional) – The size of the output tensor at dimension dim. If set to None, will create a minimal-sized output tensor according to index.max() + 1. (default: None)
- reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min", "max" or "any"). (default: "sum")