topobench.nn.readouts.hopse module
Readout function for the HOPSE model.
- class AbstractZeroCellReadOut(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Bases: Module
Readout layer for GNNs that operates on the batch level.
- Parameters:
- hidden_dim : int
Hidden dimension of the GNN model.
- out_channels : int
Number of output channels.
- task_level : str
Task level for the readout layer. Either “graph” or “node”.
- pooling_type : str
Pooling type for the readout layer. One of “max”, “sum” or “mean”.
- logits_linear_layer : bool
Whether to use a linear layer for getting the final logits.
- **kwargs : dict
Additional arguments.
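The `pooling_type` options amount to a per-graph reduction of node embeddings grouped by the batch index. A minimal sketch of that reduction, using PyTorch's built-in `Tensor.scatter_reduce` (this is not necessarily how TopoBench implements it; `pool_nodes` is a hypothetical helper, and the number of graphs is assumed to be `batch.max() + 1`):

```python
import torch


def pool_nodes(x: torch.Tensor, batch: torch.Tensor, pooling_type: str = "sum") -> torch.Tensor:
    """Hypothetical helper: pool node embeddings x [N, F] into graph embeddings [B, F]."""
    # Map the documented pooling names to scatter_reduce's reduce names.
    reduce = {"sum": "sum", "mean": "mean", "max": "amax"}[pooling_type]
    num_graphs = int(batch.max()) + 1
    out = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    # Broadcast the per-node graph index across the feature dimension.
    index = batch.unsqueeze(-1).expand_as(x)
    # include_self=False: the zero initialisation does not take part in the reduction.
    return out.scatter_reduce(0, index, x, reduce=reduce, include_self=False)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = torch.tensor([0, 0, 1])  # nodes 0 and 1 belong to graph 0, node 2 to graph 1
print(pool_nodes(x, batch, "sum").tolist())   # [[4.0, 6.0], [5.0, 6.0]]
print(pool_nodes(x, batch, "mean").tolist())  # [[2.0, 3.0], [5.0, 6.0]]
print(pool_nodes(x, batch, "max").tolist())   # [[3.0, 4.0], [5.0, 6.0]]
```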
- __init__(hidden_dim, out_channels, task_level, pooling_type='sum', logits_linear_layer=True, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute_logits(x, batch)
Compute logits based on the readout layer.
- Parameters:
- x : torch.Tensor
Node embeddings.
- batch : torch.Tensor
Batch index tensor.
- Returns:
- torch.Tensor
Logits tensor.
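One plausible reading of `compute_logits`, based only on the constructor parameters: for graph-level tasks the node embeddings are pooled per graph before the optional linear layer, while for node-level tasks the layer is applied to each node. The class below is a hypothetical stand-in, not the TopoBench source, and it supports only `"sum"` and `"mean"` pooling to stay short.

```python
import torch
from torch import nn


class SketchZeroCellReadout(nn.Module):
    """Hypothetical stand-in illustrating how a compute_logits step could work."""

    def __init__(self, hidden_dim, out_channels, task_level, pooling_type="sum", logits_linear_layer=True):
        super().__init__()
        assert task_level in ("graph", "node")
        assert pooling_type in ("sum", "mean")  # "max" is omitted in this sketch
        self.task_level = task_level
        self.pooling_type = pooling_type
        # Without the linear layer, the (pooled) embeddings are returned unchanged.
        self.linear = nn.Linear(hidden_dim, out_channels) if logits_linear_layer else nn.Identity()

    def compute_logits(self, x, batch):
        if self.task_level == "graph":
            num_graphs = int(batch.max()) + 1
            # Sum node embeddings into one row per graph.
            pooled = x.new_zeros((num_graphs, x.size(1))).index_add_(0, batch, x)
            if self.pooling_type == "mean":
                # Divide by the node count of each graph (clamped to avoid division by zero).
                counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
                pooled = pooled / counts.unsqueeze(-1).to(x.dtype)
            x = pooled
        return self.linear(x)


readout = SketchZeroCellReadout(hidden_dim=16, out_channels=3, task_level="graph", pooling_type="mean")
x = torch.randn(5, 16)                 # 5 nodes with 16-dimensional embeddings
batch = torch.tensor([0, 0, 1, 1, 1])  # two graphs
print(readout.compute_logits(x, batch).shape)  # torch.Size([2, 3])
```

For `task_level="node"` the same call returns one row of logits per node instead of one per graph.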
- abstractmethod forward(model_out, batch)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- class HOPSEReadout(**kwargs)
Bases: AbstractZeroCellReadOut
Readout function for the HOPSE model.
- Parameters:
- **kwargs : dict
Additional keyword arguments. It should contain the following keys:
  - complex_dim (int): Dimension of the simplicial complex.
  - max_hop (int): Maximum hop neighbourhood to consider.
  - hidden_dim_1 (int): Dimension of the embeddings.
  - out_channels (int): Number of classes.
  - pooling_type (str): Type of pooling operation.
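A small sketch of assembling the keyword arguments listed above and checking that none is missing before construction. The values are illustrative only, the accepted `pooling_type` values are assumed to match `AbstractZeroCellReadOut`, and `missing_hopse_kwargs` is a hypothetical helper, not part of TopoBench.

```python
# Keys listed for HOPSEReadout's **kwargs; the values below are illustrative only.
hopse_readout_kwargs = {
    "complex_dim": 2,       # dimension of the simplicial complex
    "max_hop": 2,           # maximum hop neighbourhood to consider
    "hidden_dim_1": 64,     # dimension of the embeddings
    "out_channels": 10,     # number of classes
    "pooling_type": "sum",  # assumed: same options as AbstractZeroCellReadOut ("max", "sum", "mean")
}

REQUIRED_KEYS = ("complex_dim", "max_hop", "hidden_dim_1", "out_channels", "pooling_type")


def missing_hopse_kwargs(kwargs: dict) -> list:
    """Hypothetical helper: return the listed keys that kwargs lacks, in documented order."""
    return [key for key in REQUIRED_KEYS if key not in kwargs]


missing = missing_hopse_kwargs(hopse_readout_kwargs)
if missing:
    raise KeyError(f"missing HOPSEReadout kwargs: {missing}")
print(missing)  # []
```

Failing early like this gives a clearer error than a `KeyError` raised deep inside a constructor.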
- __init__(**kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute_logits(x, batch)
Compute logits based on the readout layer.
- Parameters:
- x : torch.Tensor
Node embeddings.
- batch : torch.Tensor
Batch index tensor.
- Returns:
- torch.Tensor
Logits tensor.
- forward(model_out, batch)
Readout logic based on the model output (model_out).
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
Dictionary containing the updated model output.
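The documented contract of `forward` is "dictionary in, updated dictionary out". A minimal sketch of that pattern follows; the key names `"x_0"` and `"logits"` are hypothetical, and `SimpleNamespace` stands in for a `torch_geometric` batch whose `batch` attribute (the node-to-graph index) is assumed to be what the readout uses.

```python
from types import SimpleNamespace

import torch
from torch import nn


class SketchReadout(nn.Module):
    """Hypothetical readout illustrating the forward(model_out, batch) -> dict contract."""

    def __init__(self, hidden_dim: int, out_channels: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, out_channels)

    def forward(self, model_out: dict, batch) -> dict:
        x = model_out["x_0"]  # hypothetical key holding 0-cell (node) embeddings
        num_graphs = int(batch.batch.max()) + 1
        # Graph-level sum pooling followed by a linear layer.
        pooled = x.new_zeros((num_graphs, x.size(1))).index_add_(0, batch.batch, x)
        model_out["logits"] = self.linear(pooled)  # hypothetical output key
        return model_out  # the same dict, now carrying the logits


batch = SimpleNamespace(batch=torch.tensor([0, 0, 1]))  # stand-in for a PyG batch
model_out = {"x_0": torch.randn(3, 8)}
out = SketchReadout(hidden_dim=8, out_channels=4)(model_out, batch)
print(sorted(out))  # ['logits', 'x_0']
```

Updating the dictionary rather than returning a bare tensor lets later stages (loss, evaluation) read whichever entries they need.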
- scatter(src, index, dim=-1, out=None, dim_size=None, reduce='sum')
Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.
Formally, if src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\), although no specific ordering of indices is required. The index tensor supports broadcasting in case its dimensions do not match with src.
For one-dimensional tensors with reduce="sum", the operation computes
\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]
where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
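For intuition, the one-dimensional case can be written as a plain-Python reference loop. This is an illustrative sketch of the semantics described above, not the torch_scatter implementation; `scatter_1d` is a hypothetical name, it starts from an empty output rather than an existing out, and it sets slots that receive no value to 0.

```python
def scatter_1d(src, index, dim_size=None, reduce="sum"):
    """Hypothetical plain-Python reference for 1-D scatter (sum, mean, min, max).

    Slots that receive no value are set to 0 in this sketch.
    """
    if dim_size is None:
        dim_size = max(index) + 1  # minimal output size, i.e. index.max() + 1
    groups = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        groups[i].append(value)  # src[j] is routed to slot index[j]
    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    return [reducers[reduce](g) if g else 0 for g in groups]


src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
print(scatter_1d(src, index))                # [4, 12, 5]
print(scatter_1d(src, index, reduce="max"))  # [3, 6, 5]
print(scatter_1d(src, index, dim_size=5))    # [4, 12, 5, 0, 0]
```

Slot 0 collects src[0] and src[2] (1 + 3 = 4), slot 1 collects src[1], src[3] and src[5] (2 + 4 + 6 = 12), and slot 2 collects src[4] (5), matching the formula with \(\mathrm{out}\) initialised to zero.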
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic since the order of parallel operations to the same value is undetermined. For floating-point variables, this results in a source of variance in the result.
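The variance arises because floating-point addition is not associative: adding the same values in a different order can change the last bits of the result. A plain-Python illustration (it does not involve the GPU; it only shows why the order of accumulation matters):

```python
# Floating-point addition is not associative, so the order in which
# parallel atomic adds land can change the final sum slightly.
a, b, c = 0.1, 0.2, 0.3
left_to_right = (a + b) + c
right_to_left = (c + b) + a
print(left_to_right, right_to_left)    # 0.6000000000000001 0.6
print(left_to_right == right_to_left)  # False
```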
- Parameters:
src (Tensor) – The source tensor.
index (Tensor) – The indices of elements to scatter.
dim (int) – The axis along which to index. (default: -1)
out (Tensor | None) – The destination tensor.
dim_size (int | None) – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
reduce (str) – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")
- Return type:
Tensor
import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(out.size())
# torch.Size([10, 3, 64])