Using a new dataset#

In this tutorial we show how you can use a dataset not present in the library.

This particular example uses the ENZYMES dataset, applies a simplicial lifting to create simplicial complexes, and trains the SCN2 model. We train the model using the appropriate training and validation datasets, and finally test it on the test dataset.

Table of contents#

1. Imports

2. Configurations and utilities

3. Loading the data

4. Model initialization

5. Training

6. Testing the model

1. Imports#

 In [2]:
import lightning as pl
import torch
from omegaconf import OmegaConf
from topomodelx.nn.simplicial.scn2 import SCN2
from torch_geometric.datasets import TUDataset

from topobenchmark.data.preprocessor import PreProcessor
from topobenchmark.dataloader.dataloader import TBDataloader
from topobenchmark.evaluator.evaluator import TBEvaluator
from topobenchmark.loss.loss import TBLoss
from topobenchmark.model.model import TBModel
from topobenchmark.nn.encoders import AllCellFeatureEncoder
from topobenchmark.nn.readouts import PropagateSignalDown
from topobenchmark.nn.wrappers.simplicial import SCNWrapper
from topobenchmark.optimizer import TBOptimizer

2. Configurations and utilities#

Configurations can be specified in YAML files or directly in your code, as in this example.
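
As a sketch, the same transform configuration defined in the next cell could equally live in a YAML file (the transform.yaml name here is hypothetical) and be loaded with OmegaConf:

# A hypothetical transform.yaml:
# clique_lifting:
#   transform_type: lifting
#   transform_name: SimplicialCliqueLifting
#   complex_dim: 3
from omegaconf import OmegaConf

transform_config = OmegaConf.load("transform.yaml")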

 In [ ]:
transform_config = {
    "clique_lifting": {
        "transform_type": "lifting",
        "transform_name": "SimplicialCliqueLifting",
        "complex_dim": 3,
    }
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "random",
    "data_seed": 0,
    "data_split_dir": "./data/ENZYMES/splits/",
    "train_prop": 0.5,
}

in_channels = 3
out_channels = 6
dim_hidden = 16

wrapper_config = {
    "out_channels": dim_hidden,
    "num_cell_dimensions": 3,
}

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {
    "dataset_loss":
        {
            "task": "classification",
            "loss_type": "cross_entropy"
        }
}

evaluator_config = {"task": "classification",
                    "num_classes": out_channels,
                    "metrics": ["accuracy", "precision", "recall"]}

optimizer_config = {"optimizer_id": "Adam",
                    "parameters":
                        {"lr": 0.001,"weight_decay": 0.0005}
                    }

transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)
 In [4]:
def make_wrapper(**factory_kwargs):
    """Return a factory that wraps a backbone in SCNWrapper with the given kwargs."""

    def factory(backbone):
        return SCNWrapper(backbone, **factory_kwargs)

    return factory
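
Note that `make_wrapper` returns a factory rather than a wrapped module: the `backbone_wrapper` argument of `TBModel` (used below) is a callable that receives the backbone and returns it wrapped, so the `SCNWrapper` arguments are captured here and the actual wrapping happens when the model is assembled.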

3. Loading the data#

In this example we use the ENZYMES dataset. It is a graph dataset, and we use the clique lifting to transform its graphs into simplicial complexes: every clique of k+1 nodes becomes a k-simplex, as sketched below. We invite you to check out the README of the repository to learn more about the various liftings offered.
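
For intuition, here is a minimal sketch of the idea behind a clique lifting (not TopoBenchmark's implementation; it uses networkx, which we assume is installed):

import networkx as nx

# A triangle (0, 1, 2) plus a pendant edge (2, 3)
G = nx.Graph([(0, 1), (1, 2), (0, 2), (2, 3)])

# Each (k+1)-clique yields a k-simplex: nodes, then edges, then the filled triangle
for clique in nx.enumerate_all_cliques(G):
    print(f"{len(clique) - 1}-simplex: {clique}")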

 In [5]:
dataset_dir = "./data/ENZYMES/"
dataset = TUDataset(root=dataset_dir, name="ENZYMES")

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)
Transform parameters are the same, using existing data_dir: ./data/ENZYMES/clique_lifting/3206123057
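
As a quick sanity check, you can peek at one training batch (a minimal sketch; we assume the lifted batches expose one feature matrix per cell rank, e.g. x_0, x_1, x_2, matching the three input channels of the backbone):

# Inspect one lifted batch; TBDataloader follows the Lightning datamodule interface
batch = next(iter(datamodule.train_dataloader()))
print(batch)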

4. Model initialization#

We can create the backbone by instantiating the SCN2 model from TopoModelX. Then the SCNWrapper and the TBModel take care of the rest.

 In [6]:
backbone = SCN2(in_channels_0=dim_hidden, in_channels_1=dim_hidden, in_channels_2=dim_hidden)
wrapper = make_wrapper(**wrapper_config)

readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)

evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)
 In [7]:
model = TBModel(backbone=backbone,
                 backbone_wrapper=wrapper,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False,)

5. Training#

Now we can use the Lightning trainer to train the model.

 In [8]:
#%%capture
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=5, accelerator="cpu", enable_progress_bar=False)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.

  | Name            | Type                  | Params | Mode
------------------------------------------------------------------
0 | feature_encoder | AllCellFeatureEncoder | 1.2 K  | train
1 | backbone        | SCNWrapper            | 1.6 K  | train
2 | readout         | PropagateSignalDown   | 102    | train
3 | val_acc_best    | MeanMetric            | 0      | train
------------------------------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
/home/lev/projects/TopoBenchmark/topobenchmark/nn/wrappers/simplicial/scn_wrapper.py:75: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)
  normalized_matrix = diag_matrix @ (matrix @ diag_matrix)
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=5` reached.
 In [9]:
print('      Training metrics\n', '-'*26)
for key in train_metrics:
    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))
      Training metrics
 --------------------------
train/accuracy:       0.1567
train/precision:      0.1365
train/recall:         0.1525
val/loss:             2.3835
val/accuracy:         0.1400
val/precision:        0.1269
val/recall:           0.1830
train/loss:           2.3218

6. Testing the model#

Finally, we can test the model and obtain the results.

 In [10]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics

/home/lev/miniconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“
โ”ƒ        Test metric        โ”ƒ       DataLoader 0        โ”ƒ
โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ
โ”‚       test/accuracy       โ”‚    0.1666666716337204     โ”‚
โ”‚         test/loss         โ”‚     2.021564483642578     โ”‚
โ”‚      test/precision       โ”‚    0.08934479206800461    โ”‚
โ”‚        test/recall        โ”‚    0.15170806646347046    โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
 In [11]:
print('      Testing metrics\n', '-'*25)
for key in test_metrics:
    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))
      Testing metrics
 -------------------------
test/loss:           2.0216
test/accuracy:       0.1667
test/precision:      0.0893
test/recall:         0.1517