Source code for topobench.nn.wrappers.combinatorial.tune_wrapper

"""Wrapper for the TopoTune model."""

from topobench.nn.wrappers.base import AbstractWrapper


[docs] class TuneWrapper(AbstractWrapper): r"""Wrapper for the TopoTune model. This wrapper defines the forward pass of the TopoTune model. """
[docs] def forward(self, batch): r"""Forward pass for the Tune wrapper. Parameters ---------- batch : torch_geometric.data.Data Batch object containing the batched data. Returns ------- dict Dictionary containing the updated model output. """ x = self.backbone(batch) model_out = {"labels": batch.y, "batch_0": batch.batch_0} for key, value in x.items(): model_out[f"x_{key}"] = value return model_out