topobench.nn.readouts.mlp_readout module

MLP-based Readout over 0-cells (i.e. nodes).

MLPBackbone

alias of MLP

class MLPReadout(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)

Bases: MLP

MLP-based Readout over 0-cells (i.e. nodes).

This class implements a readout layer for graph neural networks, allowing for customizable MLP layers and pooling strategies.
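The MLP part of the readout can be pictured with a framework-free sketch of the per-row computation: affine layers with an activation between them, followed by an optional final activation. This is a minimal sketch under stated assumptions, not the TopoBench implementation: the helpers (`linear`, `mlp_forward`, `relu`, `sigmoid`) and the explicit `(weight, bias)` pairs are hypothetical, and dropout and normalization are left out for brevity.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Matrix = List[List[float]]  # rows of a weight matrix, shape (out_dim, in_dim)
Layer = Tuple[Matrix, List[float]]  # hypothetical (weight, bias) pair


def relu(v: float) -> float:
    return max(0.0, v)


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def linear(x: Sequence[float], weight: Matrix, bias: Sequence[float]) -> List[float]:
    """Affine map y = W x + b for a single feature vector."""
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


def mlp_forward(
    x: Sequence[float],
    layers: Sequence[Layer],
    act: Callable[[float], float] = relu,
    final_act: Optional[Callable[[float], float]] = None,
) -> List[float]:
    """Hidden layers use `act`; the output layer is followed by `final_act` if given."""
    h = list(x)
    for weight, bias in layers[:-1]:
        h = [act(v) for v in linear(h, weight, bias)]
    weight, bias = layers[-1]
    out = linear(h, weight, bias)
    # final_act=None leaves the raw outputs untouched.
    return out if final_act is None else [final_act(v) for v in out]


# in_channels=2 -> one hidden layer of width 2 -> out_channels=1
layers = [
    ([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    ([[1.0, 1.0]], [0.0]),
]
print(mlp_forward([2.0, 1.0], layers))  # [2.5]
print(mlp_forward([2.0, 1.0], layers, final_act=sigmoid))
```

Keeping `final_act` optional mirrors the signature, where `final_act=None` means no activation is applied after the last layer.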

Parameters:
in_channels : int

The dimensionality of the input features.

hidden_layers : int

The dimensionality of the hidden MLP layers.

out_channels : int

The dimensionality of the output features.

pooling_type : str, optional

Pooling type for the readout layer: “max”, “sum” or “mean” (default “sum”).

dropout : float, optional

The dropout rate (default 0.25).

norm : str, optional

The normalization layer to use (default None).

norm_kwargs : dict, optional

Additional keyword arguments for the normalization layer (default None).

act : str, optional

The activation function to use (default “relu”).

act_kwargs : dict, optional

Additional keyword arguments for the activation function (default None).

final_act : str, optional

The final activation function to use (default None).

final_act_kwargs : dict, optional

Additional keyword arguments for the final activation function (default None).

task_level : str, optional

Task level for the readout layer: “graph” or “node” (default None).

**kwargs

Additional keyword arguments.
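How `pooling_type` and `task_level` interact can be illustrated with a plain-Python sketch. It assumes node (0-cell) features arrive as a list of rows together with a per-node graph index, the role a batch vector plays in PyTorch Geometric; `pool_nodes` and `readout_level` are hypothetical names, not part of the TopoBench API. With `task_level="node"` the rows pass through unchanged, and with `task_level="graph"` they are reduced to one row per graph.

```python
from typing import Dict, List, Sequence

Row = List[float]


def pool_nodes(x: Sequence[Row], graph_index: Sequence[int], pooling_type: str = "sum") -> List[Row]:
    """Reduce node rows to one row per graph; assumes every graph has at least one node."""
    if pooling_type not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported pooling_type: {pooling_type!r}")
    num_graphs = max(graph_index) + 1
    groups: Dict[int, List[Row]] = {g: [] for g in range(num_graphs)}
    for row, g in zip(x, graph_index):
        groups[g].append(list(row))
    pooled: List[Row] = []
    for g in range(num_graphs):
        columns = list(zip(*groups[g]))  # feature-wise view of this graph's nodes
        if pooling_type == "sum":
            pooled.append([sum(c) for c in columns])
        elif pooling_type == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        else:  # "max"
            pooled.append([max(c) for c in columns])
    return pooled


def readout_level(
    x: Sequence[Row], graph_index: Sequence[int], task_level: str, pooling_type: str = "sum"
) -> List[Row]:
    """Hypothetical dispatch on task_level: keep node rows or pool them per graph."""
    if task_level == "node":
        return [list(row) for row in x]  # one output row per node
    if task_level == "graph":
        return pool_nodes(x, graph_index, pooling_type)  # one output row per graph
    raise ValueError(f"unsupported task_level: {task_level!r}")


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
graph_index = [0, 0, 1]  # nodes 0 and 1 belong to graph 0, node 2 to graph 1
print(readout_level(x, graph_index, "graph", "sum"))   # [[4.0, 6.0], [5.0, 6.0]]
print(readout_level(x, graph_index, "graph", "mean"))  # [[2.0, 3.0], [5.0, 6.0]]
print(readout_level(x, graph_index, "graph", "max"))   # [[3.0, 4.0], [5.0, 6.0]]
```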

__init__(in_channels, hidden_layers, out_channels, pooling_type='sum', dropout=0.25, norm=None, norm_kwargs=None, act='relu', act_kwargs=None, final_act=None, final_act_kwargs=None, task_level=None, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(model_out, batch)

Readout logic applied to model_out.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the updated model output.
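The dictionary-in, dictionary-out contract of `forward` might look like the following sketch. The key names `"x_0"` (0-cell features) and `"logits"` (readout output), the `batch_0` node-to-graph index, and pooling before the MLP head are all assumptions for illustration that this page does not confirm; a `SimpleNamespace` stands in for the batched `torch_geometric.data.Data` object, and the `head` callable stands in for the MLP.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Sequence

Row = List[float]


def sum_pool(x: Sequence[Row], index: Sequence[int]) -> List[Row]:
    """Sum node rows that share a graph index (assumes every index 0..max appears)."""
    num_graphs = max(index) + 1
    out = [[0.0] * len(x[0]) for _ in range(num_graphs)]
    for row, g in zip(x, index):
        for j, v in enumerate(row):
            out[g][j] += v
    return out


def readout_forward(
    model_out: Dict[str, Any],
    batch: Any,
    head: Callable[[Row], Row],
    task_level: str = "graph",
) -> Dict[str, Any]:
    """Hypothetical forward: pool 0-cell features for graph-level tasks, apply head, store result."""
    x = model_out["x_0"]  # assumed key holding node (0-cell) features
    if task_level == "graph":
        x = sum_pool(x, batch.batch_0)  # assumed node-to-graph index attribute
    model_out["logits"] = [head(row) for row in x]  # assumed output key
    return model_out


batch = SimpleNamespace(batch_0=[0, 0, 1])  # stand-in for a batched Data object
model_out = {"x_0": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]}
result = readout_forward(model_out, batch, head=lambda row: [row[0] + row[1]])
print(result["logits"])  # [[10.0], [11.0]]
```

Updating and returning the same dictionary matches the documented return value while leaving the other entries of the model output in place.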

scatter(src, index, dim=0, dim_size=None, reduce='sum')

Reduces all values from the src tensor at the indices specified in the index tensor along a given dimension dim. See the documentation of the torch_scatter package for more information.

Parameters:
  • src (torch.Tensor) – The source tensor.

  • index (torch.Tensor) – The index tensor.

  • dim (int, optional) – The dimension along which to index. (default: 0)

  • dim_size (int, optional) – The size of the output tensor at dimension dim. If set to None, will create a minimal-sized output tensor according to index.max() + 1. (default: None)

  • reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min", "max" or "any"). (default: "sum")
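For intuition about `scatter`, here is the reduction written out for the simplest case, a flat list reduced along `dim=0`. `scatter_1d` is a hypothetical helper: it covers `"sum"`, `"mean"`, `"mul"`, `"min"` and `"max"` but not `"any"`, and filling positions that receive no values with 0 is a choice made by this sketch. When `dim_size` is `None`, the output length is `max(index) + 1`, as the parameter description states.

```python
import math
from typing import List, Optional, Sequence

REDUCTIONS = ("sum", "mean", "mul", "min", "max")


def scatter_1d(
    src: Sequence[float],
    index: Sequence[int],
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> List[float]:
    """Reduce src[i] into out[index[i]] (dim=0 on a flat list)."""
    if reduce not in REDUCTIONS:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if dim_size is None:
        # Minimal-sized output: index.max() + 1.
        dim_size = max(index) + 1 if index else 0
    buckets: List[List[float]] = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        if not 0 <= i < dim_size:
            raise IndexError(f"index {i} out of range for dim_size={dim_size}")
        buckets[i].append(value)
    out: List[float] = []
    for values in buckets:
        if not values:
            out.append(0.0)  # sketch choice: positions with no entries become 0
        elif reduce == "sum":
            out.append(sum(values))
        elif reduce == "mean":
            out.append(sum(values) / len(values))
        elif reduce == "mul":
            out.append(math.prod(values))
        elif reduce == "min":
            out.append(min(values))
        else:  # "max"
            out.append(max(values))
    return out


src = [1.0, 2.0, 3.0, 4.0]
index = [0, 1, 0, 1]
print(scatter_1d(src, index))                # [4.0, 6.0]
print(scatter_1d(src, index, reduce="mul"))  # [3.0, 8.0]
print(scatter_1d(src, index, dim_size=3))    # [4.0, 6.0, 0.0]
```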