topobench.data.utils.split_utils module

Split utilities.

class DataloadDataset(data_lst)

Bases: Dataset

Custom dataset to return all the values added to the dataset object.

Parameters:
data_lst : list[torch_geometric.data.Data]

List of torch_geometric.data.Data objects.

__init__(data_lst)
get(idx)

Get data object from data list.

Parameters:
idx : int

Index of the data object to get.

Returns:
tuple

Tuple containing a list of all the values for the data and the corresponding keys.

len()

Return the length of the dataset.

Returns:
int

Length of the dataset.
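The get contract can be illustrated without PyTorch Geometric. What follows is a minimal sketch, assuming plain dicts stand in for torch_geometric.data.Data objects; the class name DataloadDatasetSketch is hypothetical and not part of topobench, and the real class may differ in detail.

```python
from typing import Any


class DataloadDatasetSketch:
    """Hypothetical stand-in for DataloadDataset, using dicts as graphs."""

    def __init__(self, data_lst: list[dict[str, Any]]) -> None:
        self.data_lst = data_lst

    def len(self) -> int:
        # Number of stored data objects.
        return len(self.data_lst)

    def get(self, idx: int) -> tuple[list[Any], list[str]]:
        # Return every stored value together with its key, in matching order.
        data = self.data_lst[idx]
        keys = list(data.keys())
        values = [data[key] for key in keys]
        return values, keys


graphs = [
    {"x": [[1.0], [2.0]], "edge_index": [[0, 1], [1, 0]], "y": 0},
    {"x": [[3.0]], "edge_index": [[], []], "y": 1},
]
dataset = DataloadDatasetSketch(graphs)
values, keys = dataset.get(0)
print(keys)  # ['x', 'edge_index', 'y']
```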

class StratifiedKFold(n_splits=5, *, shuffle=False, random_state=None)

Bases: _BaseKFold

Class-wise stratified K-Fold cross-validator.

Provides train/test indices to split data in train/test sets.

This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of samples for each class in y in a binary or multiclass classification setting.

Read more in the User Guide.

For a visualisation of cross-validation behaviour and a comparison between common scikit-learn split methods, refer to the scikit-learn gallery example model_selection/plot_cv_indices.py.

Note

Stratification on the class label solves an engineering problem rather than a statistical one. See stratification for more details.

Parameters:
n_splits : int, default=5

Number of folds. Must be at least 2.

Changed in version 0.22: n_splits default value changed from 3 to 5.

shuffle : bool, default=False

Whether to shuffle each class’s samples before splitting into batches. Note that the samples within each split will not be shuffled.

random_state : int, RandomState instance or None, default=None

When shuffle is True, random_state affects the ordering of the indices, which controls the randomness of each fold for each class. Otherwise, leave random_state as None. Pass an int for reproducible output across multiple function calls. See Glossary.

See also

RepeatedStratifiedKFold

Repeats Stratified K-Fold n times.

Notes

The implementation is designed to:

  • Generate test sets such that all contain the same distribution of classes, or as close as possible.

  • Be invariant to class label: relabelling y = ["Happy", "Sad"] to y = [1, 0] should not change the indices generated.

  • Preserve order dependencies in the dataset ordering, when shuffle=False: all samples from class k in some test set were contiguous in y, or separated in y by samples from classes other than k.

  • Generate test sets where the smallest and largest differ by at most one sample.

Changed in version 0.22: The previous implementation did not follow the last constraint.
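The constraints listed in the Notes can be met by a short, deterministic procedure. The following is a simplified sketch without shuffling, written in plain Python for illustration; the function name stratified_kfold_split is hypothetical, and this is not scikit-learn's source code. Labels are encoded by order of first appearance (label invariance), the sorted labels are dealt round-robin to decide how many samples of each class every fold receives (balanced classes, fold sizes within one sample of each other), and each class then hands out contiguous runs of its samples (order preservation).

```python
from collections import Counter


def stratified_kfold_split(y, n_splits=5):
    """Yield (train, test) index lists for stratified K-fold, no shuffling.

    Hypothetical helper: a simplified sketch, not scikit-learn's code.
    """
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")
    if n_splits > len(y):
        raise ValueError("n_splits cannot exceed the number of samples")

    # Encode labels by order of first appearance so that renaming the
    # classes (e.g. "Happy"/"Sad" -> 1/0) does not change the output.
    class_id = {}
    for label in y:
        class_id.setdefault(label, len(class_id))
    encoded = [class_id[label] for label in y]

    # Deal the sorted labels round-robin over the folds; allocation[f][c]
    # is then the number of class-c samples that fold f receives.
    y_order = sorted(encoded)
    allocation = [Counter(y_order[f::n_splits]) for f in range(n_splits)]

    # Within each class, hand out contiguous runs of samples to the folds,
    # which preserves the original ordering of that class.
    test_fold = [0] * len(y)
    for c in range(len(class_id)):
        members = [i for i, e in enumerate(encoded) if e == c]
        fold_ids = [f for f in range(n_splits) for _ in range(allocation[f][c])]
        for i, f in zip(members, fold_ids):
            test_fold[i] = f

    for f in range(n_splits):
        test = [i for i in range(len(y)) if test_fold[i] == f]
        train = [i for i in range(len(y)) if test_fold[i] != f]
        yield train, test


for i, (train, test) in enumerate(stratified_kfold_split([0, 0, 1, 1], n_splits=2)):
    print(f"Fold {i}: train={train} test={test}")
# Fold 0: train=[1, 3] test=[0, 2]
# Fold 1: train=[0, 2] test=[1, 3]
```

On the four-sample input used in the Examples, this sketch yields the same folds as the scikit-learn output shown there.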

Examples

>>> import numpy as np
>>> from sklearn.model_selection import StratifiedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> skf = StratifiedKFold(n_splits=2)
>>> skf.get_n_splits(X, y)
2
>>> print(skf)
StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(skf.split(X, y)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"  Test:  index={test_index}")
Fold 0:
  Train: index=[1 3]
  Test:  index=[0 2]
Fold 1:
  Train: index=[0 2]
  Test:  index=[1 3]
__init__(n_splits=5, *, shuffle=False, random_state=None)
split(X, y, groups=None)

Generate indices to split data into training and test set.

Parameters:
X : array-like of shape (n_samples, n_features)

Training data, where n_samples is the number of samples and n_features is the number of features.

Note that providing y is sufficient to generate the splits and hence np.zeros(n_samples) may be used as a placeholder for X instead of actual training data.

y : array-like of shape (n_samples,)

The target variable for supervised learning problems. Stratification is done based on the y labels.

groups : object

Always ignored, exists for compatibility.

Yields:
train : ndarray

The training set indices for that split.

test : ndarray

The testing set indices for that split.

Notes

Randomized CV splitters may return different results for each call of split. You can make the results identical by setting random_state to an integer.
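The effect of random_state on a shuffled split can be shown with a small sketch. The function name shuffled_stratified_folds is hypothetical, and it uses a simpler round-robin allocation than scikit-learn: each class's samples are shuffled with a random.Random instance seeded from random_state before being dealt to folds, so a fixed integer seed reproduces the same folds on every call, while random_state=None can give different folds each time. As described for the shuffle parameter, the indices inside each resulting fold are not shuffled; they are returned sorted.

```python
import random
from collections import defaultdict


def shuffled_stratified_folds(y, n_splits, random_state=None):
    """Return test-fold index lists, shuffling each class before dealing.

    Hypothetical helper illustrating the shuffle/random_state behaviour.
    """
    rng = random.Random(random_state)  # None seeds from the system's entropy
    by_class = defaultdict(list)
    for idx, label in enumerate(y):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    position = 0
    for members in by_class.values():
        members = members[:]  # copy so the grouped lists are not mutated
        rng.shuffle(members)  # shuffle samples within this class only
        for idx in members:
            folds[position % n_splits].append(idx)
            position += 1
    # Indices inside each fold are reported in sorted order, not shuffled.
    return [sorted(fold) for fold in folds]


y = [0, 0, 0, 1, 1, 1]
print(shuffled_stratified_folds(y, 3, random_state=7)
      == shuffled_stratified_folds(y, 3, random_state=7))  # True
```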

assign_train_val_test_mask_to_graphs(dataset, split_idx)

Split the graph dataset into train, validation, and test datasets.

Parameters:
dataset : torch_geometric.data.Dataset

Graph dataset to split.

split_idx : dict

Dictionary containing the train, validation, and test indices.

Returns:
tuple

Tuple containing the train, validation, and test datasets.
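For a multiple-graph dataset, the split described above amounts to selecting graphs by index. A minimal sketch, assuming split_idx uses the "train", "valid" and "test" keys returned by the splitting functions in this module and that the dataset supports integer indexing; the helper name split_graph_list is hypothetical, and the sketch covers only the partitioning, not any other bookkeeping the real function may do.

```python
def split_graph_list(dataset, split_idx):
    """Partition an indexable collection of graphs into three subsets.

    Hypothetical helper; split_idx maps "train"/"valid"/"test" to indices.
    """
    train = [dataset[i] for i in split_idx["train"]]
    valid = [dataset[i] for i in split_idx["valid"]]
    test = [dataset[i] for i in split_idx["test"]]
    return train, valid, test


graphs = ["g0", "g1", "g2", "g3", "g4"]  # strings stand in for graph objects
train_set, valid_set, test_set = split_graph_list(
    graphs, {"train": [0, 1, 2], "valid": [3], "test": [4]}
)
print(train_set, valid_set, test_set)  # ['g0', 'g1', 'g2'] ['g3'] ['g4']
```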

k_fold_split(labels, parameters, root=None)

Return train and validation indices for K-fold cross-validation.

If the split already exists, it is loaded automatically; otherwise, the split file is created and saved for subsequent runs.

Parameters:
labels : torch.Tensor

Label tensor.

parameters : DictConfig

Configuration parameters.

root : str, optional

Root directory for data splits. Overrides the default directory.

Returns:
dict

Dictionary containing the train, validation and test indices, with keys “train”, “valid”, and “test”.
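The load-or-create behaviour described above is a simple caching pattern. A minimal sketch follows, assuming one JSON file per fold; the real function's file format, directory layout and configuration keys are not documented here, so the helper names (load_or_create_split, make_fold), the path scheme and the choice to reuse the held-out fold as the test set are all hypothetical.

```python
import json
import os
import random
import tempfile


def load_or_create_split(n_samples, split_path, make_split):
    """Load a cached split from split_path, creating and saving it if missing.

    Hypothetical helper: make_split(n_samples) must return a dict of index
    lists keyed by "train", "valid" and "test".
    """
    if os.path.isfile(split_path):
        # A previous run already saved this split: reuse it unchanged.
        with open(split_path) as f:
            return json.load(f)
    split_idx = make_split(n_samples)
    os.makedirs(os.path.dirname(split_path) or ".", exist_ok=True)
    with open(split_path, "w") as f:
        json.dump(split_idx, f)
    return split_idx


def make_fold(n_samples, k=3, fold=0, seed=0):
    """Hypothetical K-fold generator returning one fold of a seeded permutation."""
    perm = list(range(n_samples))
    random.Random(seed).shuffle(perm)
    # Every k-th position (offset by fold) forms the held-out fold.
    valid = sorted(p for i, p in enumerate(perm) if i % k == fold)
    train = sorted(p for i, p in enumerate(perm) if i % k != fold)
    # Assumption for this sketch: the held-out fold doubles as the test set.
    return {"train": train, "valid": valid, "test": valid}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "3-fold", "0.json")
    first = load_or_create_split(10, path, make_fold)   # creates the file
    second = load_or_create_split(10, path, make_fold)  # loads it back
print(first == second)  # True
```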

load_coauthorship_hypergraph_splits(data, parameters, train_prop=0.5)

Load the split generated by the rand_train_test_idx function.

Parameters:
data : torch_geometric.data.Data

Graph dataset.

parameters : DictConfig

Configuration parameters.

train_prop : float, default=0.5

Proportion of training data.

Returns:
torch_geometric.data.Data

Graph dataset with the specified split.
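For a single graph, a node-level split is commonly attached as boolean masks. A minimal sketch, assuming a dict-based graph and index lists keyed by "train", "valid" and "test"; the attribute names train_mask, val_mask and test_mask follow a common PyTorch Geometric convention but are an assumption here, as are the helper names indices_to_mask and attach_split_masks.

```python
def indices_to_mask(indices, num_nodes):
    """Return a boolean list that is True exactly at the given node indices."""
    mask = [False] * num_nodes
    for i in indices:
        mask[i] = True
    return mask


def attach_split_masks(data, split_idx, num_nodes):
    """Store train/val/test node masks on a dict-based graph (hypothetical)."""
    # Attribute names are an assumption borrowed from common PyG datasets.
    data["train_mask"] = indices_to_mask(split_idx["train"], num_nodes)
    data["val_mask"] = indices_to_mask(split_idx["valid"], num_nodes)
    data["test_mask"] = indices_to_mask(split_idx["test"], num_nodes)
    return data


graph = attach_split_masks({}, {"train": [0, 2], "valid": [1], "test": [3]}, 4)
print(graph["train_mask"])  # [True, False, True, False]
```

With PyTorch tensors, the same mask is typically built as torch.zeros(num_nodes, dtype=torch.bool) followed by mask[indices] = True.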

load_inductive_splits(dataset, parameters)

Load multiple-graph datasets with the specified split.

Parameters:
dataset : torch_geometric.data.Dataset

Graph dataset.

parameters : DictConfig

Configuration parameters.

Returns:
list

List containing the train, validation, and test splits.

load_transductive_splits(dataset, parameters)

Load the graph dataset with the specified split.

Parameters:
dataset : torch_geometric.data.Dataset

Graph dataset.

parameters : DictConfig

Configuration parameters.

Returns:
list

List containing the train, validation, and test splits.

random_splitting(labels, parameters, root=None, global_data_seed=42)

Randomly split labels into train/valid/test splits.

Adapted from CUAI/Non-Homophily-Benchmarks.

Parameters:
labels : torch.Tensor

Label tensor.

parameters : DictConfig

Configuration parameters.

root : str, optional

Root directory for data splits. Overrides the default directory.

global_data_seed : int, default=42

Seed for the random number generator.

Returns:
dict

Dictionary containing the train, validation, and test indices, with keys “train”, “valid”, and “test”.
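A proportional random split of this kind can be sketched as a seeded permutation cut into consecutive slices. This is a hypothetical standard-library version: the proportions train_prop and valid_prop, the convention that a label of -1 marks an unlabeled sample, and the function name random_split_sketch are assumptions (the function above reads its settings from parameters), with seed standing in for global_data_seed.

```python
import random


def random_split_sketch(labels, train_prop=0.5, valid_prop=0.25, seed=42):
    """Randomly partition labelled sample indices into train/valid/test.

    Hypothetical helper; samples whose label is -1 are treated as unlabeled
    and left out of every split (an assumption for this sketch).
    """
    if train_prop <= 0 or valid_prop < 0 or train_prop + valid_prop > 1:
        raise ValueError("need 0 < train_prop and train_prop + valid_prop <= 1")
    labeled = [i for i, y in enumerate(labels) if y != -1]
    random.Random(seed).shuffle(labeled)  # seeded: same seed, same split
    train_num = int(len(labeled) * train_prop)
    valid_num = int(len(labeled) * valid_prop)
    # Whatever remains after the train and validation slices becomes the test set.
    return {
        "train": sorted(labeled[:train_num]),
        "valid": sorted(labeled[train_num:train_num + valid_num]),
        "test": sorted(labeled[train_num + valid_num:]),
    }


labels = [0, 1, 0, 1, 0, 1, 0, 1, -1, 0, 1]
split = random_split_sketch(labels)
print({key: len(idx) for key, idx in split.items()})  # {'train': 5, 'valid': 2, 'test': 3}
```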