topobench.dataloader package
Submodules
topobench.dataloader.dataload_dataset module
Dataset class compatible with TBDataloader.
- class topobench.dataloader.dataload_dataset.DataloadDataset(*args: Any, **kwargs: Any)
Bases: Dataset
Custom dataset to return all the values added to the dataset object.
- Parameters:
- data_lst : list[torch_geometric.data.Data]
List of torch_geometric.data.Data objects.
- get(idx)
Get the data object at the given index from the data list.
- Parameters:
- idx : int
Index of the data object to get.
- Returns:
- tuple
Tuple containing a list of all the values for the data and the corresponding keys.
- len()
Return the length of the dataset.
- Returns:
- int
Length of the dataset.
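The (values, keys) return contract of get can be sketched without torch_geometric. In this minimal stand-in, plain dictionaries play the role of Data objects and the class name ListBackedDataset is hypothetical; the real class subclasses the PyG Dataset and may differ in detail.

```python
from typing import Any


class ListBackedDataset:
    """Hypothetical stand-in for DataloadDataset; dicts replace Data objects."""

    def __init__(self, data_lst: list[dict[str, Any]]) -> None:
        self.data_lst = data_lst

    def get(self, idx: int) -> tuple[list[Any], list[str]]:
        # Return every stored value together with the matching keys,
        # in the same order, so a later step can rebuild the object.
        data = self.data_lst[idx]
        keys = list(data.keys())
        return [data[key] for key in keys], keys

    def len(self) -> int:
        return len(self.data_lst)


dataset = ListBackedDataset([
    {"x_0": [1.0, 2.0], "y": 0},
    {"x_0": [3.0], "y": 1},
])
values, keys = dataset.get(1)
rebuilt = dict(zip(keys, values))  # pairing keys and values restores the sample
```

Returning the keys alongside the values keeps the two lists aligned, which is what allows a collate step to reassemble each sample.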
topobench.dataloader.dataloader module
TBDataloader class.
- class topobench.dataloader.dataloader.TBDataloader(dataset_train: DataloadDataset, dataset_val: DataloadDataset = None, dataset_test: DataloadDataset = None, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, **kwargs: Any)
Bases: LightningDataModule
Lightning data module that builds the dataloaders for the training, validation, and test datasets.
It also supplies the collate function and is designed to work with torch dataloaders.
- Parameters:
- dataset_train : DataloadDataset
The training dataset.
- dataset_val : DataloadDataset, optional
The validation dataset (default: None).
- dataset_test : DataloadDataset, optional
The test dataset (default: None).
- batch_size : int, optional
The batch size for the dataloader (default: 1).
- num_workers : int, optional
The number of worker processes to use for data loading (default: 0).
- pin_memory : bool, optional
If True, the data loader will copy tensors into pinned memory before returning them (default: False).
- **kwargs : optional
Additional arguments.
- state_dict() → dict[Any, Any]
Called when saving a checkpoint. Implement to generate and save the datamodule state.
- Returns:
- dict
A dictionary containing the datamodule state that you want to save.
- teardown(stage: str | None = None) → None
Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().
- Parameters:
- stage : str, optional
The stage being torn down. Either “fit”, “validate”, “test”, or “predict” (default: None).
- test_dataloader() → torch.utils.data.DataLoader
Create and return the test dataloader.
- Returns:
- torch.utils.data.DataLoader
The test dataloader.
- train_dataloader() → torch.utils.data.DataLoader
Create and return the train dataloader.
- Returns:
- torch.utils.data.DataLoader
The train dataloader.
- val_dataloader() → torch.utils.data.DataLoader
Create and return the validation dataloader.
- Returns:
- torch.utils.data.DataLoader
The validation dataloader.
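The division of labour in this class (one module that owns three datasets and hands out a loader per split) can be sketched in plain Python. Everything here is illustrative: SplitLoaders and its _batches helper are hypothetical names, lists stand in for torch dataloaders, and raising ValueError for a missing split is an assumption rather than documented TBDataloader behaviour.

```python
from typing import Optional, Sequence


class SplitLoaders:
    """Hypothetical, torch-free analogue of a data module with three splits."""

    def __init__(
        self,
        dataset_train: Sequence,
        dataset_val: Optional[Sequence] = None,
        dataset_test: Optional[Sequence] = None,
        batch_size: int = 1,
    ) -> None:
        self.dataset_train = dataset_train
        self.dataset_val = dataset_val
        self.dataset_test = dataset_test
        self.batch_size = batch_size

    def _batches(self, dataset: Optional[Sequence], name: str) -> list[list]:
        if dataset is None:
            # Assumption: a missing optional split is reported as an error.
            raise ValueError(f"No {name} dataset was provided.")
        step = self.batch_size
        # Consecutive chunks of at most batch_size items; the last may be short.
        return [list(dataset[i:i + step]) for i in range(0, len(dataset), step)]

    def train_dataloader(self) -> list[list]:
        return self._batches(self.dataset_train, "train")

    def val_dataloader(self) -> list[list]:
        return self._batches(self.dataset_val, "validation")

    def test_dataloader(self) -> list[list]:
        return self._batches(self.dataset_test, "test")


module = SplitLoaders(dataset_train=[0, 1, 2, 3, 4], dataset_val=[5, 6], batch_size=2)
train_batches = module.train_dataloader()
val_batches = module.val_dataloader()
```

Keeping the batching settings on the module means each split is loaded with the same batch_size, which mirrors how the constructor parameters above are shared across the three dataloader methods.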
topobench.dataloader.utils module
Dataloader utilities.
- class topobench.dataloader.utils.DomainData(*args: Any, **kwargs: Any)
Bases: Data
Helper Data class that lets sparse matrices work with PyG dataloaders even when their attribute names do not contain adj.
It overrides some methods of torch_geometric.data.Data.
- is_valid(string)
Check if the string contains any of the valid names.
- Parameters:
- string : str
String to check.
- Returns:
- bool
Whether the string contains any of the valid names.
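The check performed by is_valid amounts to a substring test against a list of recognised names. A minimal sketch, assuming a hypothetical tuple VALID_NAMES; the names actually accepted by DomainData are not listed on this page and may differ:

```python
# Assumed for illustration only; not taken from the DomainData source.
VALID_NAMES = ("adj", "incidence", "laplacian")


def is_valid(string: str) -> bool:
    """Return True if any recognised name occurs as a substring."""
    return any(name in string for name in VALID_NAMES)


# Keep only the attribute names that would be treated as sparse matrices.
sparse_like = [key for key in ("x_0", "incidence_1", "adj_t", "y") if is_valid(key)]
```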
- topobench.dataloader.utils.collate_fn(batch)
Override the torch_geometric.data.DataLoader collate function to use the DomainData class.
This ensures that the torch_geometric dataloaders work with sparse matrices that are not necessarily named adj. The function also generates the batch slices for the different cell dimensions.
- Parameters:
- batch : list
List of data objects (e.g., torch_geometric.data.Data).
- Returns:
- torch_geometric.data.Batch
A torch_geometric.data.Batch object.
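The batch slices mentioned in this entry can be pictured as one index vector per cell dimension, recording which sample each cell came from. A torch-free sketch, where the function name batch_indices and the dictionary layout are hypothetical and plain lists stand in for tensors:

```python
def batch_indices(cells_per_sample: list[dict[int, int]]) -> dict[int, list[int]]:
    """Map each cell dimension to a per-cell list of sample indices.

    cells_per_sample[i][d] is the number of d-dimensional cells in sample i.
    """
    result: dict[int, list[int]] = {}
    for sample_idx, counts in enumerate(cells_per_sample):
        for dim, num_cells in counts.items():
            # Every cell of this sample is tagged with the sample's position.
            result.setdefault(dim, []).extend([sample_idx] * num_cells)
    return result


# Two samples: the first has 3 nodes and 2 edges, the second 2 nodes and 1 edge.
slices = batch_indices([{0: 3, 1: 2}, {0: 2, 1: 1}])
```

With such a vector, pooling over the cells of one sample reduces to selecting the entries whose index equals that sample's position in the batch.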
- topobench.dataloader.utils.to_data_list(batch)
Convert a batch back into a list of data objects. This workaround is needed because torch_geometric does not work when torch.sparse is used instead of torch_sparse.
- Parameters:
- batch : torch_geometric.data.Batch
The batch of data.
- Returns:
- list
List of data objects.
Module contents
This module implements the dataloader for the topobench package.
- class topobench.dataloader.DataloadDataset(*args: Any, **kwargs: Any)
Same class as topobench.dataloader.dataload_dataset.DataloadDataset, documented above.
- class topobench.dataloader.TBDataloader(dataset_train: DataloadDataset, dataset_val: DataloadDataset = None, dataset_test: DataloadDataset = None, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, **kwargs: Any)
Same class as topobench.dataloader.dataloader.TBDataloader, documented above.