topobench.dataloader package#

This module implements the dataloader for the topobench package.

class DataloadDataset(data_lst)#

Bases: Dataset

Custom dataset to return all the values added to the dataset object.

Parameters:
data_lst : list[torch_geometric.data.Data]

List of torch_geometric.data.Data objects.

__init__(data_lst)#
get(idx)#

Get data object from data list.

Parameters:
idx : int

Index of the data object to get.

Returns:
tuple

Tuple containing a list of all the values for the data and the corresponding keys.

len()#

Return the length of the dataset.

Returns:
int

Length of the dataset.
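The `get`/`len` contract above can be sketched with a minimal pure-Python stand-in. `MiniDataset` is a hypothetical class for illustration only (plain dicts replace `torch_geometric.data.Data`, and no `Dataset` base class is involved); the real `DataloadDataset` works the same way but over PyG data objects.

```python
# Hypothetical stand-in illustrating the DataloadDataset contract:
# get(idx) returns a (values, keys) tuple and len() returns the size.
class MiniDataset:
    def __init__(self, data_lst):
        self.data_lst = data_lst

    def get(self, idx):
        # Collect every key stored on the data object and its value.
        data = self.data_lst[idx]
        keys = list(data.keys())
        values = [data[k] for k in keys]
        return values, keys

    def len(self):
        return len(self.data_lst)

ds = MiniDataset([{"x": [1.0, 2.0], "y": 0}, {"x": [3.0], "y": 1}])
values, keys = ds.get(0)
print(ds.len())  # 2
print(keys)      # ['x', 'y']
```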

class TBDataloader(dataset_train, dataset_val=None, dataset_test=None, batch_size=1, num_workers=0, pin_memory=False, **kwargs)#

Bases: LightningDataModule

This class returns the dataloaders for the training, validation, and test datasets.

It also handles the collate function, and is designed to work with the standard torch dataloaders.

Parameters:
dataset_train : DataloadDataset

The training dataset.

dataset_val : DataloadDataset, optional

The validation dataset (default: None).

dataset_test : DataloadDataset, optional

The test dataset (default: None).

batch_size : int, optional

The batch size for the dataloader (default: 1).

num_workers : int, optional

The number of worker processes to use for data loading (default: 0).

pin_memory : bool, optional

If True, the data loader will copy tensors into pinned memory before returning them (default: False).

**kwargs : optional

Additional arguments.

References

Read the docs:

https://lightning.ai/docs/pytorch/latest/data/datamodule.html
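The datamodule pattern that TBDataloader follows can be sketched without torch or Lightning: hold the per-split datasets plus loader settings, and hand back one loader per split on request. `MiniDataModule` and `batched` below are hypothetical stand-ins (plain lists replace `DataloadDataset`, a batching generator replaces `torch.utils.data.DataLoader`).

```python
# Hypothetical sketch of the LightningDataModule-style pattern.
def batched(dataset, batch_size):
    """Yield successive batches (slices) from the dataset."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]

class MiniDataModule:
    def __init__(self, dataset_train, dataset_val=None, dataset_test=None,
                 batch_size=1):
        self.dataset_train = dataset_train
        self.dataset_val = dataset_val
        self.dataset_test = dataset_test
        self.batch_size = batch_size

    def train_dataloader(self):
        return batched(self.dataset_train, self.batch_size)

    def val_dataloader(self):
        return batched(self.dataset_val, self.batch_size)

    def test_dataloader(self):
        return batched(self.dataset_test, self.batch_size)

dm = MiniDataModule(list(range(10)), dataset_val=list(range(4)), batch_size=4)
print([len(b) for b in dm.train_dataloader()])  # [4, 4, 2]
```

In the real class, each `*_dataloader()` call wraps the corresponding `DataloadDataset` in a torch `DataLoader` configured with `batch_size`, `num_workers`, and `pin_memory`.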

__init__(dataset_train, dataset_val=None, dataset_test=None, batch_size=1, num_workers=0, pin_memory=False, **kwargs)#
prepare_data_per_node#

If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

allow_zero_length_dataloader_with_multiple_devices#

If True, a dataloader with zero length within the local rank is allowed. Default value is False.

state_dict()#

Called when saving a checkpoint. Implement to generate and save the datamodule state.

Returns:
dict

A dictionary containing the datamodule state that you want to save.

teardown(stage=None)#

Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().

Parameters:
stage : str, optional

The stage being torn down. Either “fit”, “validate”, “test”, or “predict” (default: None).
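The two checkpoint/cleanup hooks can be sketched together: state_dict() packages whatever datamodule state should survive a checkpoint round-trip, and teardown(stage) releases resources tied to the finished stage. `MiniHooks` and its attributes are hypothetical, for illustration only.

```python
# Hypothetical sketch of the state_dict()/teardown() hook pattern.
class MiniHooks:
    def __init__(self):
        self.batch_size = 8
        self._stage_cache = {"fit": "some-resource"}

    def state_dict(self):
        # Return the datamodule state you want saved in the checkpoint.
        return {"batch_size": self.batch_size}

    def teardown(self, stage=None):
        # Drop resources for the finished stage (or everything if None).
        if stage is None:
            self._stage_cache.clear()
        else:
            self._stage_cache.pop(stage, None)

hooks = MiniHooks()
print(hooks.state_dict())  # {'batch_size': 8}
hooks.teardown("fit")
```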

test_dataloader()#

Create and return the test dataloader.

Returns:
torch.utils.data.DataLoader

The test dataloader.

train_dataloader()#

Create and return the train dataloader.

Returns:
torch.utils.data.DataLoader

The train dataloader.

val_dataloader()#

Create and return the validation dataloader.

Returns:
torch.utils.data.DataLoader

The validation dataloader.

Submodules#