topobench.nn.readouts.hopse module

Readout function for the HOPSE model.

class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Bases: Module

Readout layer for GNNs that operates on the batch level.

Parameters:
hidden_dim : int

Hidden dimension of the GNN model.

out_channels : int

Number of output channels.

task_level : str

Task level for readout layer. Either “graph” or “node”.

pooling_type : str

Pooling type for readout layer. Either “max”, “sum” or “mean”.

logits_linear_layer : bool

Whether to use a linear layer for getting the final logits.

**kwargs : dict

Additional arguments.
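The parameters above suggest the following behaviour, sketched here under stated assumptions: for task_level "graph", node embeddings are pooled per graph using the batch index and the chosen pooling_type; for "node", each node keeps its own embedding; a linear layer then maps hidden_dim to out_channels when logits_linear_layer is True. The class `SketchReadout` is hypothetical and is not TopoBench's implementation; in particular, passing embeddings through unchanged when logits_linear_layer is False is an assumption.

```python
import torch
from torch import nn


class SketchReadout(nn.Module):
    """Hypothetical readout mirroring the documented constructor arguments."""

    def __init__(self, hidden_dim, out_channels, task_level,
                 pooling_type="sum", logits_linear_layer=True):
        super().__init__()
        if task_level not in ("graph", "node"):
            raise ValueError(f"unknown task_level: {task_level}")
        if pooling_type not in ("sum", "mean", "max"):
            raise ValueError(f"unknown pooling_type: {pooling_type}")
        self.task_level = task_level
        self.pooling_type = pooling_type
        # Assumption: without the linear layer, embeddings pass through unchanged.
        self.linear = (nn.Linear(hidden_dim, out_channels)
                       if logits_linear_layer else nn.Identity())

    def pool(self, x, batch):
        # x: [num_nodes, hidden_dim]; batch: [num_nodes] graph index of each node.
        num_graphs = int(batch.max()) + 1
        index = batch.unsqueeze(-1).expand_as(x)
        out = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype, device=x.device)
        # Map the documented pooling names onto Tensor.scatter_reduce's names.
        mode = {"sum": "sum", "mean": "mean", "max": "amax"}[self.pooling_type]
        return out.scatter_reduce(0, index, x, reduce=mode, include_self=False)

    def compute_logits(self, x, batch):
        if self.task_level == "graph":
            x = self.pool(x, batch)
        return self.linear(x)
```

With task_level "graph" the output has one row per graph; with "node" it keeps one row per node.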

__init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_logits(x, batch)

Compute logits based on the readout layer.

Parameters:
x : torch.Tensor

Node embeddings.

batch : torch.Tensor

Batch index tensor.

Returns:
torch.Tensor

Logits tensor.

abstractmethod forward(model_out, batch)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

class HOPSEReadout(**kwargs)

Bases: AbstractZeroCellReadOut

Readout function for the HOPSE model.

Parameters:
**kwargs : dict

Additional keyword arguments. It should contain the following keys:

  • complex_dim (int): Dimension of the simplicial complex.

  • max_hop (int): Maximum hop neighbourhood to consider.

  • hidden_dim_1 (int): Dimension of the embeddings.

  • out_channels (int): Number of classes.

  • pooling_type (str): Type of pooling operation.
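For illustration, the keyword arguments listed above could be collected in a dictionary like the one below and unpacked into the constructor as HOPSEReadout(**hopse_readout_kwargs). The values are placeholders, and `REQUIRED_KEYS` and `missing_keys` are hypothetical helpers, not part of TopoBench; the constructor may accept or require keys beyond the ones documented here.

```python
# Hypothetical example values; adjust them to your dataset and model.
hopse_readout_kwargs = {
    "complex_dim": 2,       # dimension of the simplicial complex
    "max_hop": 3,           # maximum hop neighbourhood to consider
    "hidden_dim_1": 64,     # dimension of the embeddings
    "out_channels": 10,     # number of classes
    "pooling_type": "sum",  # type of pooling operation
}

# Keys documented for HOPSEReadout; the helper below is illustrative only.
REQUIRED_KEYS = {"complex_dim", "max_hop", "hidden_dim_1", "out_channels", "pooling_type"}


def missing_keys(kwargs):
    """Return the documented keys absent from kwargs, in sorted order."""
    return sorted(REQUIRED_KEYS - set(kwargs))


print(missing_keys(hopse_readout_kwargs))  # []
print(missing_keys({"max_hop": 2}))
```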

__init__(**kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

compute_logits(x, batch)

Compute logits based on the readout layer.

Parameters:
x : torch.Tensor

Node embeddings.

batch : torch.Tensor

Batch index tensor.

Returns:
torch.Tensor

Logits tensor.

forward(model_out, batch)

Readout logic based on model_out.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the updated model output.
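A forward method with this signature could follow the pattern below: read the embeddings from model_out, compute logits, and return the same dictionary with the logits added. This is a sketch, not HOPSEReadout's code. The keys "x_0" and "logits", the attribute batch.batch_0, and the sum pooling inside compute_logits are all assumptions for illustration; a SimpleNamespace stands in for the torch_geometric Data object so the example runs without that dependency.

```python
from types import SimpleNamespace

import torch
from torch import nn


class SketchForwardReadout(nn.Module):
    """Hypothetical readout illustrating the forward(model_out, batch) contract."""

    def __init__(self, hidden_dim, out_channels):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, out_channels)

    def compute_logits(self, x, batch):
        # Assumed behaviour: graph-level sum pooling followed by a linear layer.
        num_graphs = int(batch.max()) + 1
        pooled = torch.zeros(num_graphs, x.size(-1), dtype=x.dtype, device=x.device)
        pooled.index_add_(0, batch, x)
        return self.linear(pooled)

    def forward(self, model_out, batch):
        # Assumed keys: "x_0" holds 0-cell embeddings; batch.batch_0 maps cells to graphs.
        model_out["logits"] = self.compute_logits(model_out["x_0"], batch.batch_0)
        return model_out


readout = SketchForwardReadout(hidden_dim=8, out_channels=2)
batch = SimpleNamespace(batch_0=torch.tensor([0, 0, 1]))
out = readout({"x_0": torch.randn(3, 8)}, batch)
print(sorted(out))          # ['logits', 'x_0']
print(out["logits"].shape)  # torch.Size([2, 2])
```

Updating and returning the input dictionary keeps earlier model outputs available to later stages alongside the new logits.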

scatter(src, index, dim=-1, out=None, dim_size=None, reduce='sum')


Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.
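A minimal pure-Python reference for the one-dimensional case makes the grouping-then-reducing semantics concrete. `scatter_1d` is a hypothetical helper written for illustration, not part of torch_scatter; it covers only the "sum", "mean", "min" and "max" modes, and filling positions that receive no value with 0 is a choice made in this sketch.

```python
def scatter_1d(src, index, dim_size=None, reduce="sum"):
    """Reference 1-D scatter: group src values by index, then reduce each group."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    size = dim_size if dim_size is not None else max(index) + 1
    groups = [[] for _ in range(size)]
    for value, i in zip(src, index):
        groups[i].append(value)
    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    fn = reducers[reduce]
    # Positions that receive no value are filled with 0 in this sketch.
    return [fn(g) if g else 0 for g in groups]


src = [1.0, 2.0, 3.0, 4.0, 5.0]
index = [0, 1, 0, 1, 3]
print(scatter_1d(src, index))                # [4.0, 6.0, 0, 5.0]
print(scatter_1d(src, index, reduce="max"))  # [3.0, 4.0, 0, 5.0]
```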

Formally, if src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\), although no specific ordering of indices is required. The index tensor supports broadcasting in case its dimensions do not match with src.
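The shape rule can be expressed as a small function. `scatter_out_shape` is a hypothetical helper for illustration: it replaces the size along dim with y, taking y from dim_size when given and otherwise from max(index) + 1, as the dim_size parameter describes.

```python
def scatter_out_shape(src_shape, index_values, dim=-1, dim_size=None):
    """Shape of scatter's output: src_shape with the axis `dim` replaced by y."""
    shape = list(src_shape)
    axis = dim % len(shape)  # normalise negative axes such as -1
    y = dim_size if dim_size is not None else max(index_values) + 1
    shape[axis] = y
    return tuple(shape)


# Same shapes as the torch_scatter example at the end of this section.
print(scatter_out_shape((10, 6, 64), [0, 1, 0, 1, 2, 1], dim=1))  # (10, 3, 64)
```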

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
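As a worked instance of the formula, with numbers chosen purely for illustration, take \(\mathrm{out} = (10, 20)\), \(\mathrm{src} = (1, 2, 3, 4)\) and \(\mathrm{index} = (0, 1, 0, 1)\). Because each update adds to the existing value of \(\mathrm{out}_i\), a pre-filled destination tensor accumulates rather than being overwritten:

\[\begin{aligned}
\mathrm{out}_0 &= 10 + (\mathrm{src}_0 + \mathrm{src}_2) = 10 + (1 + 3) = 14,\\
\mathrm{out}_1 &= 20 + (\mathrm{src}_1 + \mathrm{src}_3) = 20 + (2 + 4) = 26.
\end{aligned}\]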

Note

This operation is implemented via atomic operations on the GPU and is therefore non-deterministic since the order of parallel operations to the same value is undetermined. For floating-point variables, this results in a source of variance in the result.
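The variance described in the note arises because floating-point addition is not associative, so the order in which parallel updates are applied can change the last bits of a sum. Plain Python floats (IEEE 754 doubles) show the same effect:

```python
# Summing the same three numbers in two different orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)  # 0.6000000000000001 0.6 False
```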

Parameters:
  • src (Tensor) – The source tensor.

  • index (Tensor) – The indices of elements to scatter.

  • dim (int) – The axis along which to index. (default: -1)

  • out (Tensor | None) – The destination tensor.

  • dim_size (int | None) – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.

  • reduce (str) – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")

Return type:

Tensor

import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())
# torch.Size([10, 3, 64])
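If torch_scatter is not installed, PyTorch's built-in Tensor.scatter_reduce (available since PyTorch 1.12) can reproduce the example above. The helper `native_scatter` is hypothetical and only handles a one-dimensional index; unlike torch_scatter, the built-in does not broadcast the index, so it is reshaped and expanded explicitly. The built-in names its min/max modes "amin"/"amax", and edge cases such as empty groups should be checked before treating the two as interchangeable.

```python
import torch


def native_scatter(src, index, dim, dim_size=None, reduce="sum"):
    """Hypothetical helper: a torch_scatter-style call built on Tensor.scatter_reduce."""
    dim = dim % src.dim()
    size = list(src.shape)
    size[dim] = dim_size if dim_size is not None else int(index.max()) + 1
    # Reshape the 1-D index so it lies along `dim`, then expand it to src's shape.
    view = [1] * src.dim()
    view[dim] = -1
    full_index = index.view(view).expand_as(src)
    out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # include_self=False: the zero-initialised entries do not take part in the reduction.
    return out.scatter_reduce(dim, full_index, src, reduce=reduce, include_self=False)


src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])
out = native_scatter(src, index, dim=1, reduce="sum")
print(out.size())  # torch.Size([10, 3, 64])
```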