topobench.data.preprocessor.preprocessor module#

Preprocessor for datasets.

class DataTransform(transform_name, **kwargs)#

Bases: BaseTransform

Abstract class to define a custom data lifting.

Parameters:
transform_name : str

The name of the transform to be used.

**kwargs : dict

Additional arguments for the class. Should contain “transform_name”.

__init__(transform_name, **kwargs)#
forward(data)#

Forward pass of the lifting.

Parameters:
data : torch_geometric.data.Data

The input data to be lifted.

Returns:
torch_geometric.data.Data

The lifted data.
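
A minimal usage sketch, assuming a toy graph; the lifting name passed as transform_name is hypothetical and depends on the transforms registered in the installed configuration:

    import torch
    from torch_geometric.data import Data

    from topobench.data.preprocessor.preprocessor import DataTransform

    # Toy graph: 3 nodes with 4 features each and 2 undirected edges.
    data = Data(
        x=torch.randn(3, 4),
        edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    )

    # "transform_name" selects the registered transform to apply;
    # the name used here is a hypothetical example.
    transform = DataTransform(transform_name="SimplicialCliqueLifting")
    lifted = transform(data)  # BaseTransform.__call__ dispatches to forward(data)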

class DataloadDataset(data_lst)#

Bases: Dataset

Custom dataset to return all the values added to the dataset object.

Parameters:
data_lst : list[torch_geometric.data.Data]

List of torch_geometric.data.Data objects.

__init__(data_lst)#
get(idx)#

Get data object from data list.

Parameters:
idx : int

Index of the data object to get.

Returns:
tuple

Tuple containing a list of all the values stored in the data object and a list of the corresponding keys.

len()#

Return the length of the dataset.

Returns:
int

Length of the dataset.
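
A minimal sketch of wrapping a small list of Data objects; the toy graphs are illustrative:

    import torch
    from torch_geometric.data import Data

    from topobench.data.preprocessor.preprocessor import DataloadDataset

    data_lst = [
        Data(x=torch.randn(4, 8), y=torch.tensor([0])),
        Data(x=torch.randn(5, 8), y=torch.tensor([1])),
    ]
    dataset = DataloadDataset(data_lst)

    print(dataset.len())           # 2
    values, keys = dataset.get(0)  # stored values and their keys, per the documented return order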

class PreProcessor(dataset, data_dir, transforms_config=None, **kwargs)#

Bases: InMemoryDataset

Preprocessor for datasets.

Parameters:
dataset : list

List of data objects.

data_dir : str

Path to the directory containing the data.

transforms_config : DictConfig, optional

Configuration parameters for the transforms (default: None).

**kwargs : optional

Optional additional arguments.

__init__(dataset, data_dir, transforms_config=None, **kwargs)#
instantiate_pre_transform(data_dir, transforms_config)#

Instantiate the pre-transforms.

Parameters:
data_dir : str

Path to the directory containing the data.

transforms_config : DictConfig

Configuration parameters for the transforms.

Returns:
torch_geometric.transforms.Compose

Pre-transform object.

load(path)#

Load the dataset from the given file path.

Parameters:
path : str

The path to the processed data.

load_dataset_splits(split_params)#

Load the dataset splits.

Parameters:
split_params : dict

Parameters for loading the dataset splits.

Returns:
tuple

A tuple containing the train, validation, and test datasets.

process()#

Process the data.

save_transform_parameters()#

Save the transform parameters.

set_processed_data_dir(pre_transforms_dict, data_dir, transforms_config)#

Set the processed data directory.

Parameters:
pre_transforms_dict : dict

Dictionary containing the pre-transforms.

data_dir : str

Path to the directory containing the data.

transforms_config : DictConfig

Configuration parameters for the transforms.

property processed_dir: str#

Return the path to the processed directory.

Returns:
str

Path to the processed directory.

property processed_file_names: str#

Return the name of the processed file.

Returns:
str

Name of the processed file.
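
A minimal construction sketch, assuming the dataset is given as a plain list of Data objects and no pre-transforms are configured; the toy graphs and the data directory are illustrative:

    import torch
    from torch_geometric.data import Data

    from topobench.data.preprocessor.preprocessor import PreProcessor

    # Two toy graphs standing in for a real dataset.
    dataset = [
        Data(x=torch.randn(4, 8), edge_index=torch.tensor([[0, 1], [1, 0]])),
        Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 2], [2, 0]])),
    ]

    # With transforms_config=None no pre-transform is instantiated.
    preprocessor = PreProcessor(dataset, data_dir="./datasets/demo")

    print(preprocessor.processed_dir)         # path to the processed directory
    print(preprocessor.processed_file_names)  # name of the processed file

    # Split parameters come from the experiment configuration; their exact keys
    # are not reproduced here, so the call is only sketched:
    # train, val, test = preprocessor.load_dataset_splits(split_params)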

ensure_serializable(obj)#

Ensure that the object is serializable.

Parameters:
obj : object

Object to ensure serializability.

Returns:
object

Object that is serializable.
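
A short sketch of the call shape, e.g. before saving transform parameters to disk; the nested dictionary is illustrative, and exactly which values get converted depends on the implementation:

    from topobench.data.preprocessor.preprocessor import ensure_serializable

    # A nested, configuration-like object (illustrative).
    params = {"lifting": {"complex_dim": 2}, "feature_fields": ("x", "pos")}

    serializable = ensure_serializable(params)  # equivalent object safe to serialize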

load_inductive_splits(dataset, parameters)#

Load multiple-graph datasets with the specified split.

Parameters:
dataset : torch_geometric.data.Dataset

Graph dataset.

parameters : DictConfig

Configuration parameters.

Returns:
list

List containing the train, validation, and test splits.

load_transductive_splits(dataset, parameters)#

Load the graph dataset with the specified split.

Parameters:
dataset : torch_geometric.data.Dataset

Graph dataset.

parameters : DictConfig

Configuration parameters.

Returns:
list

List containing the train, validation, and test splits.
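
Only the call shape is sketched here; the keys expected inside parameters (split strategy, ratios, split directory, and so on) come from the project configuration and are not reproduced:

    from topobench.data.preprocessor.preprocessor import (
        load_inductive_splits,
        load_transductive_splits,
    )

    # `dataset` and `split_config` are placeholders prepared elsewhere.
    # train, val, test = load_inductive_splits(dataset, split_config)
    # train, val, test = load_transductive_splits(dataset, split_config)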

make_hash(o)#

Make a hash from a dictionary, list, tuple, or set, nested to any level, that contains only other hashable types.

Parameters:
o : dict, list, tuple, set

Object to hash.

Returns:
int

Hash of the object.
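
A minimal sketch; structurally identical nested contents are expected to hash to the same value. The configuration dictionaries are illustrative:

    from topobench.data.preprocessor.preprocessor import make_hash

    cfg_a = {"lifting": {"name": "clique", "dim": 2}, "fields": ["x", "pos"]}
    cfg_b = {"lifting": {"name": "clique", "dim": 2}, "fields": ["x", "pos"]}

    assert make_hash(cfg_a) == make_hash(cfg_b)  # equal contents, equal hash
    print(make_hash(cfg_a))                      # an int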