topobench.dataloader package
This module implements the dataloader for the topobench package.
- class DataloadDataset(data_lst)
Bases: Dataset
Custom dataset to return all the values added to the dataset object.
- Parameters:
- data_lst : list[torch_geometric.data.Data]
List of torch_geometric.data.Data objects.
- __init__(data_lst)
- get(idx)
Get data object from data list.
- Parameters:
- idx : int
Index of the data object to get.
- Returns:
- tuple
Tuple containing a list of all the values for the data and the corresponding keys.
- len()
Return the length of the dataset.
- Returns:
- int
Length of the dataset.
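The get() and len() entries above describe a simple contract: each item comes back as a pair of parallel lists, values first and keys second. The following is a minimal sketch of that contract only. `SketchDataset` is a hypothetical stand-in, not the topobench class, and `SimpleNamespace` objects stand in for torch_geometric.data.Data so the example runs without any third-party packages.

```python
from types import SimpleNamespace


class SketchDataset:
    """Hypothetical stand-in mirroring the documented get()/len() contract."""

    def __init__(self, data_lst):
        self.data_lst = data_lst

    def get(self, idx):
        # Split one record into parallel lists of values and keys.
        record = vars(self.data_lst[idx])
        keys = list(record)
        values = [record[key] for key in keys]
        return values, keys

    def len(self):
        # Number of records held by the dataset.
        return len(self.data_lst)


# Plain namespaces stand in for torch_geometric.data.Data objects (assumption).
samples = [
    SimpleNamespace(x=[1.0, 2.0], y=0),
    SimpleNamespace(x=[3.0], y=1),
]
dataset = SketchDataset(samples)
values, keys = dataset.get(1)
```

Returning the keys alongside the values lets a later step rebuild each record by name, whatever attributes were attached to it.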
- class TBDataloader(dataset_train, dataset_val=None, dataset_test=None, batch_size=1, num_workers=0, pin_memory=False, **kwargs)
Bases: LightningDataModule
This class takes care of returning the dataloaders for the training, validation, and test datasets.
It also handles the collate function. The class is designed to work with the torch dataloaders.
- Parameters:
- dataset_train : DataloadDataset
The training dataset.
- dataset_val : DataloadDataset, optional
The validation dataset (default: None).
- dataset_test : DataloadDataset, optional
The test dataset (default: None).
- batch_size : int, optional
The batch size for the dataloader (default: 1).
- num_workers : int, optional
The number of worker processes to use for data loading (default: 0).
- pin_memory : bool, optional
If True, the data loader will copy tensors into pinned memory before returning them (default: False).
- **kwargs : optional
Additional arguments.
- __init__(dataset_train, dataset_val=None, dataset_test=None, batch_size=1, num_workers=0, pin_memory=False, **kwargs)
- prepare_data_per_node
If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices
If True, dataloader with zero length within local rank is allowed. Default value is False.
- state_dict()
Called when saving a checkpoint. Implement to generate and save the datamodule state.
- Returns:
- dict
A dictionary containing the datamodule state that you want to save.
- teardown(stage=None)
Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().
- Parameters:
- stage : str, optional
The stage being torn down. Either “fit”, “validate”, “test”, or “predict” (default: None).
- test_dataloader()
Create and return the test dataloader.
- Returns:
- torch.utils.data.DataLoader
The test dataloader.
- train_dataloader()
Create and return the train dataloader.
- Returns:
- torch.utils.data.DataLoader
The train dataloader.
- val_dataloader()
Create and return the validation dataloader.
- Returns:
- torch.utils.data.DataLoader
The validation dataloader.
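The batch_size parameter and the collate step mentioned for TBDataloader can be pictured with a small dependency-free sketch. It assumes items arrive as (values, keys) pairs, as described for DataloadDataset.get(). The helpers `rebuild` and `iter_batches` are hypothetical names introduced for illustration; the package's own collate function merges items into batched objects, which this sketch does not attempt.

```python
def rebuild(item):
    """Turn one (values, keys) pair back into a dict keyed by field name."""
    values, keys = item
    return dict(zip(keys, values))


def iter_batches(items, batch_size=1):
    """Yield lists of rebuilt records, batch_size at a time."""
    for start in range(0, len(items), batch_size):
        chunk = items[start:start + batch_size]
        yield [rebuild(item) for item in chunk]


# Three hypothetical items in the (values, keys) layout.
items = [
    ([[1.0, 2.0], 0], ["x", "y"]),
    ([[3.0], 1], ["x", "y"]),
    ([[4.0], 0], ["x", "y"]),
]
batches = list(iter_batches(items, batch_size=2))
```

With three items and batch_size=2, the final batch holds a single record, which matches how torch.utils.data.DataLoader behaves under its default drop_last=False.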
Submodules
- topobench.dataloader.dataload_dataset module
- topobench.dataloader.dataloader module
  - Any
  - DataLoader: DataLoader.__init__(), DataLoader.check_worker_number_rationality(), DataLoader.batch_size, DataLoader.dataset, DataLoader.drop_last, DataLoader.multiprocessing_context, DataLoader.num_workers, DataLoader.pin_memory, DataLoader.pin_memory_device, DataLoader.prefetch_factor, DataLoader.sampler, DataLoader.timeout
  - DataloadDataset
  - LightningDataModule: LightningDataModule.__init__(), LightningDataModule.from_datasets(), LightningDataModule.load_from_checkpoint(), LightningDataModule.load_state_dict(), LightningDataModule.on_exception(), LightningDataModule.state_dict(), LightningDataModule.CHECKPOINT_HYPER_PARAMS_KEY, LightningDataModule.CHECKPOINT_HYPER_PARAMS_NAME, LightningDataModule.CHECKPOINT_HYPER_PARAMS_TYPE, LightningDataModule.name
  - TBDataloader
  - collate_fn()
- topobench.dataloader.utils module
  - Any
  - DomainData
  - SparseTensor: SparseTensor.__init__(), SparseTensor.add(), SparseTensor.add_(), SparseTensor.add_nnz(), SparseTensor.add_nnz_(), SparseTensor.avg_bandwidth(), SparseTensor.avg_col_length(), SparseTensor.avg_row_length(), SparseTensor.bandwidth(), SparseTensor.bandwidth_proportion(), SparseTensor.bfloat16(), SparseTensor.bool(), SparseTensor.byte(), SparseTensor.char(), SparseTensor.clear_cache_(), SparseTensor.clone(), SparseTensor.coalesce(), SparseTensor.coo(), SparseTensor.copy(), SparseTensor.cpu(), SparseTensor.csc(), SparseTensor.csr(), SparseTensor.cuda(), SparseTensor.density(), SparseTensor.detach(), SparseTensor.detach_(), SparseTensor.device(), SparseTensor.device_as(), SparseTensor.dim(), SparseTensor.double(), SparseTensor.dtype(), SparseTensor.eye(), SparseTensor.fill_cache_(), SparseTensor.fill_diag(), SparseTensor.fill_value(), SparseTensor.fill_value_(), SparseTensor.float(), SparseTensor.from_dense(), SparseTensor.from_edge_index(), SparseTensor.from_scipy(), SparseTensor.from_storage(), SparseTensor.from_torch_sparse_coo_tensor(), SparseTensor.from_torch_sparse_csr_tensor(), SparseTensor.get_diag(), SparseTensor.half(), SparseTensor.has_value(), SparseTensor.index_select(), SparseTensor.index_select_nnz(), SparseTensor.int(), SparseTensor.is_coalesced(), SparseTensor.is_cuda(), SparseTensor.is_floating_point(), SparseTensor.is_pinned(), SparseTensor.is_quadratic(), SparseTensor.is_shared(), SparseTensor.is_symmetric(), SparseTensor.long(), SparseTensor.masked_select(), SparseTensor.masked_select_nnz(), SparseTensor.matmul(), SparseTensor.max(), SparseTensor.mean(), SparseTensor.min(), SparseTensor.mul(), SparseTensor.mul_(), SparseTensor.mul_nnz(), SparseTensor.mul_nnz_(), SparseTensor.narrow(), SparseTensor.nnz(), SparseTensor.numel(), SparseTensor.partition(), SparseTensor.permute(), SparseTensor.pin_memory(), SparseTensor.random_walk(), SparseTensor.remove_diag(), SparseTensor.requires_grad(), SparseTensor.requires_grad_(), SparseTensor.reverse_cuthill_mckee(), SparseTensor.saint_subgraph(), SparseTensor.sample(), SparseTensor.sample_adj(), SparseTensor.select(), SparseTensor.set_diag(), SparseTensor.set_value(), SparseTensor.set_value_(), SparseTensor.share_memory_(), SparseTensor.short(), SparseTensor.size(), SparseTensor.sizes(), SparseTensor.sparse_reshape(), SparseTensor.sparse_resize(), SparseTensor.sparse_size(), SparseTensor.sparse_sizes(), SparseTensor.sparsity(), SparseTensor.spmm(), SparseTensor.spspmm(), SparseTensor.sum(), SparseTensor.t(), SparseTensor.to(), SparseTensor.to_dense(), SparseTensor.to_device(), SparseTensor.to_scipy(), SparseTensor.to_symmetric(), SparseTensor.to_torch_sparse_coo_tensor(), SparseTensor.to_torch_sparse_csc_tensor(), SparseTensor.to_torch_sparse_csr_tensor(), SparseTensor.type(), SparseTensor.type_as(), SparseTensor.storage
  - defaultdict
  - collate_fn()
  - to_data_list()