topobench.nn.readouts.identical module
Readout layer that does not perform any operation on the node embeddings.
- class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Bases: Module
Readout layer for GNNs that operates on the batch level.
- Parameters:
- hidden_dim : int
Hidden dimension of the GNN model.
- out_channels : int
Number of output channels.
- task_level : str
Task level for the readout layer. Either “graph” or “node”.
- pooling_type : str
Pooling type for the readout layer. Either “max”, “sum” or “mean”.
- logits_linear_layer : bool
Whether to use a linear layer to compute the final logits.
- **kwargs : dict
Additional arguments.
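The documented choices for `task_level` and `pooling_type` are small closed sets, so they can be checked up front. The helper below, `validate_readout_options`, is hypothetical and not part of TopoBench; whether the real constructor validates its arguments this way is an assumption.

```python
# Hypothetical helper: checks readout options against the documented values.
VALID_TASK_LEVELS = {"graph", "node"}
VALID_POOLING_TYPES = {"max", "sum", "mean"}


def validate_readout_options(task_level: str, pooling_type: str = "sum") -> None:
    """Raise ValueError if an option falls outside the documented set."""
    if task_level not in VALID_TASK_LEVELS:
        raise ValueError(
            f"task_level must be one of {sorted(VALID_TASK_LEVELS)}, got {task_level!r}"
        )
    if pooling_type not in VALID_POOLING_TYPES:
        raise ValueError(
            f"pooling_type must be one of {sorted(VALID_POOLING_TYPES)}, got {pooling_type!r}"
        )


validate_readout_options("graph", "mean")  # accepted silently
try:
    validate_readout_options("edge")  # "edge" is not a documented task level
except ValueError as err:
    print(err)
```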
- __init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute_logits(x, batch)
Compute logits based on the readout layer.
- Parameters:
- x : torch.Tensor
Node embeddings.
- batch : torch.Tensor
Batch index tensor assigning each node to its graph in the batch.
- Returns:
- torch.Tensor
Logits tensor.
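The entry above only says that logits are computed from node embeddings and a batch index. One plausible reading, suggested by the `task_level`, `pooling_type` and `logits_linear_layer` parameters, is: for graph-level tasks, pool the node rows of each graph into one row, then project to `out_channels` with a linear layer. The pure-Python sketch below illustrates that reading with hypothetical helpers (`pool_by_batch`, `linear`, `compute_logits_sketch`); the real method operates on `torch.Tensor`s, and its exact behaviour, including what happens when `logits_linear_layer` is false, is an assumption here.

```python
from collections import defaultdict
from typing import List, Sequence

Matrix = List[List[float]]


def pool_by_batch(x: Sequence[Sequence[float]], batch: Sequence[int],
                  pooling_type: str = "sum") -> Matrix:
    """Reduce the node rows that share a batch index into one row per graph.

    Assumes graph ids in `batch` run 0..G-1 without gaps, as in PyG-style batching.
    """
    groups = defaultdict(list)
    for row, graph_id in zip(x, batch):
        groups[graph_id].append(row)
    pooled = []
    for graph_id in sorted(groups):
        columns = list(zip(*groups[graph_id]))  # one tuple per feature dimension
        if pooling_type == "sum":
            pooled.append([sum(c) for c in columns])
        elif pooling_type == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        elif pooling_type == "max":
            pooled.append([max(c) for c in columns])
        else:
            raise ValueError(f"unknown pooling_type {pooling_type!r}")
    return pooled


def linear(x: Sequence[Sequence[float]], weight: Matrix, bias: List[float]) -> Matrix:
    """Compute x @ weight^T + bias, with weight shaped (out_channels, hidden_dim)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in x
    ]


def compute_logits_sketch(x, batch, task_level, pooling_type, weight, bias,
                          logits_linear_layer=True):
    """Hypothetical stand-in for compute_logits: pool for graph tasks, then project."""
    if task_level == "graph":
        x = pool_by_batch(x, batch, pooling_type)
    if not logits_linear_layer:
        return [list(row) for row in x]  # assumption: no projection when disabled
    return linear(x, weight, bias)


# Three nodes with hidden_dim = 2; nodes 0-1 belong to graph 0, node 2 to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
weight, bias = [[1.0, 1.0]], [0.0]  # out_channels = 1
print(compute_logits_sketch(x, batch, "graph", "sum", weight, bias))  # [[10.0], [11.0]]
```

With `task_level="node"` the pooling step is skipped, so the sketch returns one logit row per node instead of one per graph.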
- abstract forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
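Because `forward` is abstract, each concrete readout must supply its own. The sketch below mirrors that contract with Python's `abc` module. The class names are hypothetical, and the `"x_0"` and `"logits"` keys are assumptions about what `model_out` holds. Since the sketch derives from `abc.ABC`, instantiating the base raises `TypeError`; `abstractmethod` only blocks instantiation when the class's metaclass is `ABCMeta`, so whether the real class enforces this depends on how it is combined with `torch.nn.Module`.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ReadOutBase(ABC):
    """Hypothetical stand-in for the abstract forward contract."""

    @abstractmethod
    def forward(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        """Return the (possibly updated) model output dictionary."""


class CopyToLogitsReadOut(ReadOutBase):
    """Hypothetical subclass that adds a 'logits' entry taken from 'x_0'."""

    def forward(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        out = dict(model_out)       # shallow copy so the caller's dict is not mutated
        out["logits"] = out["x_0"]  # placeholder for a real logits computation
        return out


try:
    ReadOutBase()  # abstract: cannot be instantiated
except TypeError as err:
    print("TypeError:", err)

print(CopyToLogitsReadOut().forward({"x_0": [[0.5, -0.5]]}, batch=None))
```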
- class NoReadOut(**kwargs)
Bases: AbstractZeroCellReadOut
No readout layer.
This readout layer does not perform any operation on the node embeddings.
- Parameters:
- **kwargs : dict, optional
Additional keyword arguments.
- __init__(**kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(model_out, batch)
Forward pass of the no-readout layer.
It returns the model output unchanged.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the model output.
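`NoReadOut.forward` acts as the identity on `model_out`. A pure-Python analogue with the hypothetical name `NoReadOutSketch` makes that contract concrete. Returning the very same dict object rather than a copy is an assumption consistent with "without any modification", and the sketch simply stores `**kwargs` instead of forwarding them to a base class.

```python
from typing import Any, Dict


class NoReadOutSketch:
    """Hypothetical pure-Python analogue of NoReadOut: forward is the identity."""

    def __init__(self, **kwargs: Any) -> None:
        # Keyword arguments are only stored here, for illustration.
        self.kwargs = kwargs

    def forward(self, model_out: Dict[str, Any], batch: Any) -> Dict[str, Any]:
        # Neither model_out nor batch is inspected; the dict is returned as-is.
        return model_out


model_out = {"x_0": [[0.1, 0.2]], "logits": [[0.3]]}
readout = NoReadOutSketch(hidden_dim=2, out_channels=1, task_level="node")
result = readout.forward(model_out, batch=None)
print(result is model_out)  # True
```

In this sketch, entries the backbone already produced, such as precomputed logits, pass through untouched.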