topobench.nn.readouts.base module

Abstract base class for readout layers.

class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Bases: Module

Readout layer for GNNs that operates on the batch level.

Parameters:
hidden_dim : int

Hidden dimension of the GNN model.

out_channels : int

Number of output channels.

task_level : str

Task level for the readout layer. Either “graph” or “node”.

pooling_type : str

Pooling type for the readout layer. Either “max”, “sum” or “mean”.

logits_linear_layer : bool

Whether to apply a linear layer to produce the final logits.

**kwargs : dict

Additional arguments.

__init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_logits(x, batch)

Compute the logits of the readout layer.

Parameters:
x : torch.Tensor

Node embeddings.

batch : torch.Tensor

Batch index tensor.

Returns:
torch.Tensor

Logits tensor.
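To make the roles of task_level, pooling_type and logits_linear_layer concrete, here is a torch-free sketch of what compute_logits plausibly does. It assumes (not confirmed by this page) that graph-level tasks first pool node embeddings into one row per graph using the batch index, and that the optional linear layer is then applied row by row. The helper names pool_by_graph and compute_logits_sketch, and the weight/bias arguments standing in for the linear layer, are hypothetical.

```python
from typing import List, Optional

Matrix = List[List[float]]


def pool_by_graph(x: Matrix, batch: List[int], pooling_type: str) -> Matrix:
    """Reduce node rows to one row per graph (assumes every graph has >= 1 node)."""
    groups: List[Matrix] = [[] for _ in range(max(batch) + 1)]
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)  # batch[i] is the graph that node i belongs to
    reducers = {
        "sum": sum,
        "mean": lambda col: sum(col) / len(col),
        "max": max,
    }
    if pooling_type not in reducers:
        raise ValueError(f"unknown pooling_type: {pooling_type!r}")
    reducer = reducers[pooling_type]
    # zip(*rows) walks the feature columns of one graph's node rows.
    return [[reducer(col) for col in zip(*rows)] for rows in groups]


def compute_logits_sketch(
    x: Matrix,
    batch: List[int],
    task_level: str,
    pooling_type: str = "sum",
    weight: Optional[Matrix] = None,
    bias: Optional[List[float]] = None,
) -> Matrix:
    """Hypothetical stand-in for compute_logits on plain Python lists."""
    if task_level == "graph":
        x = pool_by_graph(x, batch, pooling_type)
    elif task_level != "node":
        raise ValueError(f"unknown task_level: {task_level!r}")
    if weight is None:
        # logits_linear_layer=False: the (pooled) embeddings are returned as-is.
        return x
    # Linear layer with weight of shape (out_channels, hidden_dim): y = W x + b per row.
    bias = bias if bias is not None else [0.0] * len(weight)
    return [
        [sum(w * v for w, v in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]
```

For example, with x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] and batch = [0, 0, 1], graph-level "sum" pooling yields [[4.0, 6.0], [5.0, 6.0]], while task_level="node" leaves one row per node.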

abstract forward(model_out, batch)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
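Because forward is abstract, each concrete readout decides how to read embeddings from model_out and where to store the logits. The sketch below illustrates that contract without torch. The dictionary keys "x_0", "batch_0" and "logits" are assumptions chosen for illustration (this page does not specify them), a plain dict stands in for the torch_geometric Data object, and ReadOutSketch and NodeReadOutSketch are hypothetical classes, not part of topobench.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ReadOutSketch(ABC):
    """Hypothetical, torch-free stand-in for AbstractZeroCellReadOut."""

    def compute_logits(self, x: List[List[float]], batch: List[int]) -> List[List[float]]:
        # Placeholder for the real pooling + linear step: identity here.
        return x

    @abstractmethod
    def forward(self, model_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
        """Turn the model output into logits; subclasses must implement this."""


class NodeReadOutSketch(ReadOutSketch):
    def forward(self, model_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
        # Assumed keys: "x_0" holds 0-cell (node) embeddings, "batch_0" maps each
        # node to its graph, and the result is written back under "logits".
        model_out["logits"] = self.compute_logits(model_out["x_0"], batch["batch_0"])
        return model_out
```

Calling NodeReadOutSketch().forward({"x_0": [[1.0], [2.0]]}, {"batch_0": [0, 0]}) returns the same dictionary with a "logits" entry added, while ReadOutSketch() itself cannot be instantiated.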

abstractmethod(funcobj)

A decorator indicating abstract methods.

Requires that the metaclass is ABCMeta or derived from it. A class that has a metaclass derived from ABCMeta cannot be instantiated unless all of its abstract methods are overridden. The abstract methods can be called using any of the normal ‘super’ call mechanisms. abstractmethod() may be used to declare abstract methods for properties and descriptors.

Usage:

    class C(metaclass=ABCMeta):
        @abstractmethod
        def my_abstract_method(self, ...):
            ...
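The two less obvious points in this docstring, abstract properties and calling an abstract method's body through super(), can be checked with a small example that uses only the standard-library abc module. The class names Readout, SumReadout and Incomplete are illustrative only.

```python
from abc import ABCMeta, abstractmethod


class Readout(metaclass=ABCMeta):
    @property
    @abstractmethod
    def name(self) -> str:
        """Abstract property: every concrete subclass must provide it."""

    @abstractmethod
    def describe(self) -> str:
        # An abstract method may still have a body, reachable via super().
        return "readout"


class SumReadout(Readout):
    @property
    def name(self) -> str:
        return "sum"

    def describe(self) -> str:
        # Reuses the abstract base implementation through super().
        return super().describe() + ":" + self.name


class Incomplete(Readout):
    def describe(self) -> str:
        return "incomplete"  # `name` is still abstract, so this class is too
```

SumReadout().describe() returns "readout:sum", whereas both Readout() and Incomplete() raise TypeError because at least one abstract member is not overridden.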

scatter(src, index, dim=0, dim_size=None, reduce='sum')

Reduces all values from the src tensor at the indices specified in the index tensor along a given dimension dim. See the documentation of the torch_scatter package for more information.

Parameters:
  • src (torch.Tensor) – The source tensor.

  • index (torch.Tensor) – The index tensor.

  • dim (int, optional) – The dimension along which to index. (default: 0)

  • dim_size (int, optional) – The size of the output tensor at dimension dim. If set to None, will create a minimal-sized output tensor according to index.max() + 1. (default: None)

  • reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min", "max" or "any"). (default: "sum")
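The reduction semantics above can be illustrated with a plain-Python reference for the one-dimensional case (dim=0). This scatter_1d helper is a hypothetical teaching sketch, not the torch implementation: it covers only "sum", "mean", "min" and "max", and it assumes that output slots receiving no values are filled with 0.

```python
from typing import List, Optional


def scatter_1d(
    src: List[float],
    index: List[int],
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> List[float]:
    """Reduce src[i] into output slot index[i] (1-D sketch of scatter)."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if dim_size is None:
        # Minimal output size, as documented: index.max() + 1.
        dim_size = max(index) + 1 if index else 0
    buckets: List[List[float]] = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        buckets[i].append(value)  # raises IndexError if i >= dim_size
    reducers = {
        "sum": sum,
        "mean": lambda b: sum(b) / len(b),
        "min": min,
        "max": max,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    # Assumption: slots that received no values are filled with 0.
    return [reducers[reduce](b) if b else 0.0 for b in buckets]
```

With src = [1.0, 2.0, 3.0, 4.0] and index = [0, 1, 0, 1], "sum" gives [4.0, 6.0] and "mean" gives [2.0, 3.0]; passing dim_size=3 adds an empty third slot, so "sum" gives [4.0, 6.0, 0.0].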