topobench.nn.readouts.identical module

Readout layer that does not perform any operation on the node embeddings.

class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Bases: Module

Abstract base readout layer for GNNs that operates at the batch level.

Parameters:
hidden_dim : int

Hidden dimension of the GNN model.

out_channels : int

Number of output channels.

task_level : str

Task level for the readout layer. Either “graph” or “node”.

pooling_type : str, default 'sum'

Pooling type for the readout layer. Either “max”, “sum” or “mean”.

logits_linear_layer : bool, default True

Whether to apply a linear layer to compute the final logits.

**kwargs : dict

Additional keyword arguments.
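The parameters above can be read as a small configuration contract. The sketch below is a hypothetical stand-in (`ZeroCellReadOutConfigSketch` is not part of topobench) that assumes the constructor rejects values outside the documented `task_level` and `pooling_type` choices and uses `torch.nn.Linear(hidden_dim, out_channels)` when `logits_linear_layer` is true, falling back to `torch.nn.Identity()` otherwise. The real implementation may validate or build these pieces differently.

```python
import torch
from torch import nn


class ZeroCellReadOutConfigSketch(nn.Module):
    """Hypothetical illustration of the documented constructor parameters."""

    TASK_LEVELS = ("graph", "node")
    POOLING_TYPES = ("max", "sum", "mean")

    def __init__(self, hidden_dim, out_channels, task_level,
                 pooling_type="sum", logits_linear_layer=True, **kwargs):
        super().__init__()
        # Reject values outside the choices listed in the docstring.
        if task_level not in self.TASK_LEVELS:
            raise ValueError(f"task_level must be one of {self.TASK_LEVELS}")
        if pooling_type not in self.POOLING_TYPES:
            raise ValueError(f"pooling_type must be one of {self.POOLING_TYPES}")
        self.task_level = task_level
        self.pooling_type = pooling_type
        # Assumption: without the logits layer the embeddings pass through
        # unchanged, so out_channels only matters when logits_linear_layer=True.
        self.linear = (nn.Linear(hidden_dim, out_channels)
                       if logits_linear_layer else nn.Identity())


readout = ZeroCellReadOutConfigSketch(hidden_dim=16, out_channels=3, task_level="graph")
print(readout.linear)  # Linear(in_features=16, out_features=3, bias=True)
```

Choosing `nn.Identity()` rather than `None` keeps the call site uniform: downstream code can always call `self.linear(x)` regardless of the flag.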

__init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_logits(x, batch)

Compute the logits from the node embeddings according to the readout configuration.

Parameters:
x : torch.Tensor

Node embeddings.

batch : torch.Tensor

Batch index tensor that maps each node to its graph in the batch.

Returns:
torch.Tensor

Logits tensor.
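One plausible reading of `compute_logits` is sketched below as a standalone function (`compute_logits_sketch` is hypothetical, not topobench's API). It assumes that for graph-level tasks the node embeddings are pooled per graph according to `batch` and `pooling_type`, that for node-level tasks they stay per node, and that the result then goes through the logits layer. The pooling uses `torch.Tensor.scatter_reduce` (PyTorch 1.12+) as an illustrative choice; the library may rely on a different scatter implementation.

```python
import torch
from torch import nn

# Map the documented pooling names to scatter_reduce reduction names.
_REDUCE = {"sum": "sum", "mean": "mean", "max": "amax"}


def compute_logits_sketch(x, batch, task_level, pooling_type, linear):
    """Hypothetical logits computation for the zero-cell readout.

    x:     [num_nodes, hidden_dim] node embeddings
    batch: [num_nodes] graph index of each node (values 0..num_graphs-1)
    """
    if task_level == "graph":
        num_graphs = int(batch.max()) + 1
        # Broadcast the per-node graph index across the feature dimension.
        index = batch.unsqueeze(-1).expand_as(x)
        out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype, device=x.device)
        # include_self=False keeps the zero initial values out of mean/max.
        x = out.scatter_reduce(0, index, x, reduce=_REDUCE[pooling_type],
                               include_self=False)
    # Node-level tasks skip pooling; both paths end in the logits layer.
    return linear(x)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])  # two graphs: nodes {0, 1} and node {2}
print(compute_logits_sketch(x, batch, "graph", "mean", nn.Identity()))
# tensor([[2., 3.],
#         [5., 6.]])
```

With `"sum"` the same input would give `[[4., 6.], [5., 6.]]`, and with `"max"` it would give `[[3., 4.], [5., 6.]]`.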

abstract forward(model_out, batch)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
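Concrete readouts implement `forward`, which receives the model output dictionary together with the batch. The sketch below is a hypothetical subclass (`ReluReadOutSketch`, not part of topobench) written against plain `torch.nn.Module` so it stays self-contained. The key `"x_0"` for the node (zero-cell) embeddings is an assumption, and the ReLU is purely illustrative.

```python
import torch
from torch import nn


class ReluReadOutSketch(nn.Module):
    """Hypothetical readout that post-processes node embeddings in model_out."""

    def forward(self, model_out, batch):
        # Assumption: node (zero-cell) embeddings are stored under "x_0".
        # Copy the dict so the caller's model_out is not mutated.
        out = dict(model_out)
        out["x_0"] = torch.relu(model_out["x_0"])
        return out  # still a dictionary, matching the documented contract


model_out = {"x_0": torch.tensor([[-1.0, 2.0], [3.0, -4.0]])}
result = ReluReadOutSketch()(model_out, batch=None)
print(result["x_0"])
```

Returning a shallow copy instead of editing `model_out` in place is a design choice of this sketch; an in-place update would also satisfy the dictionary-in, dictionary-out contract.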

class NoReadOut(**kwargs)

Bases: AbstractZeroCellReadOut

No readout layer.

This readout layer does not perform any operation on the node embeddings.

Parameters:
**kwargs : dict, optional

Additional keyword arguments.

__init__(**kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(model_out, batch)

Forward pass of the NoReadOut layer.

It returns the model output unchanged.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output.
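Because `NoReadOut.forward` returns its input unmodified, it behaves as an identity on the model-output dictionary. The minimal stand-in below (`NoReadOutSketch`, hypothetical and not imported from topobench) mirrors only that documented `forward`; it does not reproduce anything else the base class may do, such as computing logits.

```python
import torch
from torch import nn


class NoReadOutSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented NoReadOut.forward."""

    def __init__(self, **kwargs):
        super().__init__()  # extra keyword arguments are accepted and ignored

    def forward(self, model_out, batch):
        # No operation on the embeddings: hand back the same dictionary.
        return model_out


model_out = {"x_0": torch.randn(4, 8)}
assert NoReadOutSketch(hidden_dim=8)(model_out, batch=None) is model_out
```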