topobench.nn.readouts.base module
Abstract base class for readout layers.
- class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Bases: torch.nn.Module
Readout layer for GNNs that operates on the batch level.
- Parameters:
- hidden_dim : int
Hidden dimension of the GNN model.
- out_channels : int
Number of output channels.
- task_level : str
Task level for readout layer. Either “graph” or “node”.
- pooling_type : str
Pooling type for readout layer. Either “max”, “sum” or “mean”.
- logits_linear_layer : bool
Whether to use a linear layer for getting the final logits.
- **kwargs : dict
Additional arguments.
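The parameter descriptions above restrict `task_level` and `pooling_type` to a small set of strings. As a sketch only, the constructor arguments can be collected in a hypothetical `ReadoutConfig` dataclass (not part of topobench) that rejects values outside the documented choices; whether the real class performs the same checks is not stated here, and `**kwargs` is left out.

```python
from dataclasses import dataclass

# Allowed values, taken from the parameter descriptions above.
TASK_LEVELS = ("graph", "node")
POOLING_TYPES = ("max", "sum", "mean")


@dataclass(frozen=True)
class ReadoutConfig:
    """Hypothetical container for the AbstractZeroCellReadOut arguments."""

    hidden_dim: int
    out_channels: int
    task_level: str
    pooling_type: str = "sum"
    logits_linear_layer: bool = True

    def __post_init__(self) -> None:
        # Fail early on values outside the documented choices.
        if self.task_level not in TASK_LEVELS:
            raise ValueError(f"task_level must be one of {TASK_LEVELS}, got {self.task_level!r}")
        if self.pooling_type not in POOLING_TYPES:
            raise ValueError(f"pooling_type must be one of {POOLING_TYPES}, got {self.pooling_type!r}")


cfg = ReadoutConfig(hidden_dim=64, out_channels=2, task_level="graph")
print(cfg)
```

Freezing the dataclass keeps a validated configuration from being mutated into an invalid one later.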
- __init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute_logits(x, batch)
Compute logits based on the readout layer.
- Parameters:
- x : torch.Tensor
Node embeddings.
- batch : torch.Tensor
Batch index tensor.
- Returns:
- torch.Tensor
Logits tensor.
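A minimal pure-Python sketch of what `compute_logits` might do, under stated assumptions: graph-level tasks pool node embeddings per graph using the batch index and the configured `pooling_type`, node-level tasks keep one row per node, and `logits_linear_layer=True` corresponds to an affine map from `hidden_dim` to `out_channels`. The helpers `pool_by_batch` and `compute_logits_sketch` are hypothetical, and lists of lists stand in for `torch.Tensor`.

```python
from typing import List, Optional

Matrix = List[List[float]]


def pool_by_batch(x: Matrix, batch: List[int], pooling_type: str) -> Matrix:
    """Pool node rows of x into one row per graph id in batch.

    Assumes every graph id from 0 to max(batch) owns at least one node.
    """
    groups: List[Matrix] = [[] for _ in range(max(batch) + 1)]
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)

    pooled: Matrix = []
    for rows in groups:
        cols = list(zip(*rows))  # one tuple per feature dimension
        if pooling_type == "sum":
            pooled.append([float(sum(c)) for c in cols])
        elif pooling_type == "mean":
            pooled.append([sum(c) / len(c) for c in cols])
        elif pooling_type == "max":
            pooled.append([float(max(c)) for c in cols])
        else:
            raise ValueError(f"unknown pooling_type {pooling_type!r}")
    return pooled


def compute_logits_sketch(
    x: Matrix,
    batch: List[int],
    task_level: str,
    pooling_type: str = "sum",
    weight: Optional[Matrix] = None,
    bias: Optional[List[float]] = None,
) -> Matrix:
    """Hypothetical compute_logits: pool (graph level only), then an optional affine map."""
    h = pool_by_batch(x, batch, pooling_type) if task_level == "graph" else x
    if weight is None:
        # Stands in for logits_linear_layer=False: return the (pooled) embeddings.
        return h
    if bias is None:
        bias = [0.0] * len(weight)
    # weight has shape (out_channels, hidden_dim): out[j] = sum_k weight[j][k] * h[k] + bias[j]
    return [
        [sum(w * v for w, v in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in h
    ]


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 nodes, hidden_dim = 2
batch = [0, 0, 1]                         # nodes 0-1 in graph 0, node 2 in graph 1
print(compute_logits_sketch(x, batch, "graph", weight=[[1.0, 1.0]], bias=[0.5]))  # [[10.5], [11.5]]
```

With sum pooling, graph 0 becomes `[4.0, 6.0]` and graph 1 stays `[5.0, 6.0]`; the single output channel then adds the two features plus the bias.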
- abstract forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
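Because `forward` is abstract, a concrete readout has to override it. The sketch below shows the shape of such a subclass using a torch-free, hypothetical stand-in base class; the keys `"x_0"`, `"batch_0"` and `"logits"` are assumptions, as is the idea that `forward` writes the logits back into `model_out`. As the `abstractmethod` documentation in this module notes, the decorator is only enforced when the metaclass derives from `ABCMeta`, so the stand-in inherits from `abc.ABC`, and plain dicts stand in for `torch_geometric.data.Data`.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ReadoutBaseSketch(ABC):
    """Torch-free, hypothetical stand-in for AbstractZeroCellReadOut."""

    def __init__(self, task_level: str) -> None:
        self.task_level = task_level

    def compute_logits(self, x: List[List[float]], batch: List[int]) -> List[List[float]]:
        # Placeholder only: returns x unchanged instead of pooling and projecting.
        return x

    @abstractmethod
    def forward(self, model_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
        """Concrete readouts must implement this."""


class SimpleReadout(ReadoutBaseSketch):
    def forward(self, model_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
        # "x_0", "batch_0" and "logits" are hypothetical key names.
        model_out["logits"] = self.compute_logits(model_out["x_0"], batch["batch_0"])
        return model_out


out = SimpleReadout("node").forward({"x_0": [[1.0, 2.0]]}, {"batch_0": [0]})
print(out["logits"])
```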
- abstractmethod(funcobj)
A decorator indicating abstract methods.
Requires that the metaclass is ABCMeta or derived from it. A class that has a metaclass derived from ABCMeta cannot be instantiated unless all of its abstract methods are overridden. The abstract methods can be called using any of the normal ‘super’ call mechanisms. abstractmethod() may be used to declare abstract methods for properties and descriptors.
Usage:

    class C(metaclass=ABCMeta):
        @abstractmethod
        def my_abstract_method(self, ...):
            ...
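The usage note shows only the basic pattern. The runnable example below also exercises two points from the description: an abstract method can still be reached through `super()`, and `abstractmethod()` can be combined with `property` when applied as the innermost decorator. `abc.ABC` is the standard-library helper class whose metaclass is `ABCMeta`; the `Shape`, `Square` and `Unfinished` classes are illustrative only.

```python
from abc import ABC, abstractmethod


class Shape(ABC):
    @property
    @abstractmethod  # innermost, as required when combined with property
    def name(self) -> str:
        """Subclasses must provide a name."""

    @abstractmethod
    def describe(self) -> str:
        # An abstract method may still have a body that subclasses reach via super().
        return f"shape:{self.name}"


class Square(Shape):
    @property
    def name(self) -> str:
        return "square"

    def describe(self) -> str:
        return super().describe() + " with 4 sides"


class Unfinished(Shape):
    # Overrides describe() but not the abstract name property.
    def describe(self) -> str:
        return "unfinished"


print(Square().describe())  # shape:square with 4 sides
for cls in (Shape, Unfinished):
    try:
        cls()
    except TypeError as exc:
        print(f"{cls.__name__} cannot be instantiated: {exc}")
```

`Unfinished` fails because every abstract member, including the property, must be overridden before instantiation succeeds.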
- scatter(src, index, dim=0, dim_size=None, reduce='sum')
Reduces all values from the src tensor at the indices specified in the index tensor along a given dimension dim. See the documentation of the torch_scatter package for more information.
- Parameters:
- src (torch.Tensor) – The source tensor.
- index (torch.Tensor) – The index tensor.
- dim (int, optional) – The dimension along which to index. (default: 0)
- dim_size (int, optional) – The size of the output tensor at dimension dim. If set to None, will create a minimal-sized output tensor according to index.max() + 1. (default: None)
- reduce (str, optional) – The reduce operation ("sum", "mean", "mul", "min", "max" or "any"). (default: "sum")
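To make the reduction semantics concrete, here is a pure-Python reference restricted to `dim=0` on a 2-D list of rows. It is a hypothetical helper (`scatter_ref`), not the library function: it supports "sum", "mean", "mul", "min" and "max", omits "any", and zero-fills output rows that receive no source rows, which is an assumption of this sketch.

```python
import math
from typing import List, Optional

Rows = List[List[float]]


def scatter_ref(src: Rows, index: List[int], dim_size: Optional[int] = None, reduce: str = "sum") -> Rows:
    """Hypothetical pure-Python reference for scatter along dim=0 of a 2-D list."""
    if dim_size is None:
        dim_size = max(index) + 1  # minimal size: index.max() + 1
    width = len(src[0]) if src else 0
    buckets: List[Rows] = [[] for _ in range(dim_size)]
    for row, i in zip(src, index):
        buckets[i].append(row)  # gather every source row that targets output row i

    reducers = {
        "sum": lambda c: float(sum(c)),
        "mean": lambda c: sum(c) / len(c),
        "mul": lambda c: float(math.prod(c)),
        "min": lambda c: float(min(c)),
        "max": lambda c: float(max(c)),
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce {reduce!r}")
    fn = reducers[reduce]

    out: Rows = []
    for rows in buckets:
        if not rows:
            out.append([0.0] * width)  # assumption of this sketch: empty slots are zero
        else:
            out.append([fn(col) for col in zip(*rows)])
    return out


src = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
index = [0, 2, 0]
print(scatter_ref(src, index))  # source rows 0 and 2 are summed into output row 0
print(scatter_ref(src, index, dim_size=4, reduce="max"))
```

Passing `dim_size` larger than `max(index) + 1` simply appends empty (here zero-filled) output rows, which is useful when the number of graphs in a batch is known in advance.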