topobench.data.utils.split_utils module
Split utilities.
- class DataloadDataset(data_lst)
Bases: Dataset
Custom dataset to return all the values added to the dataset object.
- Parameters:
- data_lst : list[torch_geometric.data.Data]
List of torch_geometric.data.Data objects.
- __init__(data_lst)
- get(idx)
Get data object from data list.
- Parameters:
- idx : int
Index of the data object to get.
- Returns:
- tuple
Tuple containing a list of all the values for the data and the corresponding keys.
- len()
Return the length of the dataset.
- Returns:
- int
Length of the dataset.
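The get/len contract described above can be illustrated with a minimal stand-in. This is only a sketch: the class name ToyDataloadDataset is hypothetical, and plain dictionaries are assumed in place of torch_geometric.data.Data objects so that the example runs without PyTorch Geometric.

```python
class ToyDataloadDataset:
    """Hypothetical stand-in: each item is a dict mapping field name to value."""

    def __init__(self, data_lst):
        self.data_lst = data_lst

    def get(self, idx):
        # Return every stored value together with the matching keys,
        # mirroring the documented (values, keys) tuple.
        data = self.data_lst[idx]
        keys = list(data.keys())
        values = [data[key] for key in keys]
        return values, keys

    def len(self):
        return len(self.data_lst)


dataset = ToyDataloadDataset([{"x": [1.0, 2.0], "y": 0}, {"x": [3.0], "y": 1}])
values, keys = dataset.get(1)
```

Returning the keys alongside the values lets a downstream collate step rebuild each object without assuming a fixed set of attributes.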
- class StratifiedKFold(n_splits=5, *, shuffle=False, random_state=None)
Bases: _BaseKFold
Class-wise stratified K-Fold cross-validator.
Provides train/test indices to split data in train/test sets.
This cross-validation object is a variation of KFold that returns stratified folds. The folds are made by preserving the percentage of samples for each class in y in a binary or multiclass classification setting.
Read more in the User Guide.
For a visualisation of cross-validation behaviour and a comparison between common scikit-learn split methods, refer to the scikit-learn gallery example plot_cv_indices.py (auto_examples/model_selection).
Note
Stratification on the class label solves an engineering problem rather than a statistical one. See stratification for more details.
- Parameters:
- n_splits : int, default=5
Number of folds. Must be at least 2.
Changed in version 0.22: n_splits default value changed from 3 to 5.
- shuffle : bool, default=False
Whether to shuffle each class’s samples before splitting into batches. Note that the samples within each split will not be shuffled.
- random_state : int, RandomState instance or None, default=None
When shuffle is True, random_state affects the ordering of the indices, which controls the randomness of each fold for each class. Otherwise, leave random_state as None. Pass an int for reproducible output across multiple function calls. See Glossary.
See also
RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
Notes
The implementation is designed to:
- Generate test sets such that all contain the same distribution of classes, or as close as possible.
- Be invariant to class label: relabelling y = ["Happy", "Sad"] to y = [1, 0] should not change the indices generated.
- Preserve order dependencies in the dataset ordering, when shuffle=False: all samples from class k in some test set were contiguous in y, or separated in y by samples from classes other than k.
- Generate test sets where the smallest and largest differ by at most one sample.
Changed in version 0.22: The previous implementation did not follow the last constraint.
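The label-invariance and fold-size properties listed in these notes can be checked directly. The sketch below assumes scikit-learn and NumPy are installed and uses a small made-up label vector; it compares the test folds obtained from string labels with those obtained after relabelling to integers.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.zeros((6, 1))  # placeholder features; only y drives the split
y_str = np.array(["Happy", "Sad", "Happy", "Sad", "Happy", "Sad"])
y_int = np.array([1, 0, 1, 0, 1, 0])  # same classes, relabelled

skf = StratifiedKFold(n_splits=3)  # shuffle=False, so the split is deterministic
folds_str = [test.tolist() for _, test in skf.split(X, y_str)]
folds_int = [test.tolist() for _, test in skf.split(X, y_int)]

# Relabelling should not change which indices land in each test fold.
same_folds = folds_str == folds_int
fold_sizes = [len(fold) for fold in folds_str]
```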
Examples
>>> import numpy as np
>>> from sklearn.model_selection import StratifiedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> skf = StratifiedKFold(n_splits=2)
>>> skf.get_n_splits(X, y)
2
>>> print(skf)
StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(skf.split(X, y)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"  Test:  index={test_index}")
Fold 0:
  Train: index=[1 3]
  Test:  index=[0 2]
Fold 1:
  Train: index=[0 2]
  Test:  index=[1 3]
- __init__(n_splits=5, *, shuffle=False, random_state=None)
- split(X, y, groups=None)
Generate indices to split data into training and test set.
- Parameters:
- X : array-like of shape (n_samples, n_features)
Training data, where n_samples is the number of samples and n_features is the number of features.
Note that providing y is sufficient to generate the splits and hence np.zeros(n_samples) may be used as a placeholder for X instead of actual training data.
- y : array-like of shape (n_samples,)
The target variable for supervised learning problems. Stratification is done based on the y labels.
- groups : object
Always ignored, exists for compatibility.
- Yields:
- train : ndarray
The training set indices for that split.
- test : ndarray
The testing set indices for that split.
Notes
Randomized CV splitters may return different results for each call of split. You can make the results identical by setting random_state to an integer.
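A short check of the reproducibility note, assuming scikit-learn and NumPy are available. The helper collect_test_folds is a hypothetical name introduced only for this sketch; it builds a shuffled splitter with a fixed integer seed and records the test indices of each fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.zeros((8, 1))  # placeholder features
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])


def collect_test_folds(seed):
    """Hypothetical helper: test indices per fold for a given seed."""
    skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=seed)
    return [test.tolist() for _, test in skf.split(X, y)]


first = collect_test_folds(0)
second = collect_test_folds(0)  # same integer seed, so the same folds

# Each test fold keeps the class balance: two samples of class 1 out of four.
class_one_counts = [int(y[fold].sum()) for fold in first]
```

With random_state left as None and shuffle=True, repeated calls are not guaranteed to agree, which is why a fixed integer is passed here.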
- assign_train_val_test_mask_to_graphs(dataset, split_idx)
Split the graph dataset into train, validation, and test datasets.
- Parameters:
- dataset : torch_geometric.data.Dataset
Considered dataset.
- split_idx : dict
Dictionary containing the train, validation, and test indices.
- Returns:
- tuple
Tuple containing the train, validation, and test datasets.
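The idea of carving one dataset into three by index can be sketched as follows. This is not the library's implementation: partition_by_split is a hypothetical helper, strings stand in for graph objects, and the keys "train", "valid", and "test" are assumed to match those returned by the split functions in this module.

```python
def partition_by_split(items, split_idx):
    """Hypothetical helper: select the items of each split by index."""
    return tuple(
        [items[i] for i in split_idx[name]]
        for name in ("train", "valid", "test")
    )


graphs = ["g0", "g1", "g2", "g3", "g4"]  # stand-ins for graph objects
split_idx = {"train": [0, 3, 4], "valid": [1], "test": [2]}
train, valid, test = partition_by_split(graphs, split_idx)
```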
- k_fold_split(labels, parameters, root=None)
Return train and valid indices as in K-Fold Cross-Validation.
If the split already exists, it is loaded automatically; otherwise the split file is created for subsequent runs.
- Parameters:
- labels : torch.Tensor
Label tensor.
- parameters : DictConfig
Configuration parameters.
- root : str, optional
Root directory for data splits. Overrides the default directory.
- Returns:
- dict
Dictionary containing the train, validation and test indices, with keys “train”, “valid”, and “test”.
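The load-or-create behaviour can be sketched as below, assuming scikit-learn and NumPy are installed. Everything here is illustrative rather than the library's code: toy_k_fold_split, its arguments, and the fold_{i}.npz file names are hypothetical, and reusing the held-out fold for both "valid" and "test" is an assumption made for the sketch.

```python
import os
import tempfile

import numpy as np
from sklearn.model_selection import StratifiedKFold


def toy_k_fold_split(labels, k, fold, split_dir, seed=0):
    """Hypothetical sketch: load a cached fold, creating all folds if missing."""
    path = os.path.join(split_dir, f"fold_{fold}.npz")
    if not os.path.isfile(path):
        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        placeholder = np.zeros(len(labels))  # y alone determines the split
        for i, (train_idx, valid_idx) in enumerate(skf.split(placeholder, labels)):
            # Assumption: the held-out fold doubles as the test set.
            np.savez(
                os.path.join(split_dir, f"fold_{i}.npz"),
                train=train_idx,
                valid=valid_idx,
                test=valid_idx,
            )
    with np.load(path) as data:
        return {key: data[key] for key in ("train", "valid", "test")}


labels = np.array([0, 0, 0, 1, 1, 1])
with tempfile.TemporaryDirectory() as split_dir:
    first = toy_k_fold_split(labels, k=3, fold=0, split_dir=split_dir)   # creates files
    again = toy_k_fold_split(labels, k=3, fold=0, split_dir=split_dir)   # loads from disk
```

Caching the indices on disk keeps the folds identical across runs, so results from different experiments are compared on the same partitions.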
- load_coauthorship_hypergraph_splits(data, parameters, train_prop=0.5)
Load the split generated by the rand_train_test_idx function.
- Parameters:
- data : torch_geometric.data.Data
Graph dataset.
- parameters : DictConfig
Configuration parameters.
- train_prop : float
Proportion of training data.
- Returns:
- torch_geometric.data.Data
Graph dataset with the specified split.
- load_inductive_splits(dataset, parameters)
Load multiple-graph datasets with the specified split.
- Parameters:
- dataset : torch_geometric.data.Dataset
Graph dataset.
- parameters : DictConfig
Configuration parameters.
- Returns:
- list
List containing the train, validation, and test splits.
- load_transductive_splits(dataset, parameters)
Load the graph dataset with the specified split.
- Parameters:
- dataset : torch_geometric.data.Dataset
Graph dataset.
- parameters : DictConfig
Configuration parameters.
- Returns:
- list
List containing the train, validation, and test splits.
- random_splitting(labels, parameters, root=None, global_data_seed=42)
Randomly split labels into train/valid/test splits.
Adapted from CUAI/Non-Homophily-Benchmarks.
- Parameters:
- labels : torch.Tensor
Label tensor.
- parameters : DictConfig
Configuration parameters.
- root : str, optional
Root directory for data splits. Overrides the default directory.
- global_data_seed : int
Seed for the random number generator.
- Returns:
- dict
Dictionary containing the train, validation and test indices with keys “train”, “valid”, and “test”.
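A seeded random split of this kind can be sketched with NumPy alone. The function toy_random_split and its train_prop and valid_prop arguments are hypothetical names chosen for illustration; the actual function reads its settings from the parameters config, whose fields are not documented here.

```python
import numpy as np


def toy_random_split(n_samples, train_prop=0.5, valid_prop=0.25, seed=42):
    """Hypothetical sketch: permute indices once, then slice into three parts."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_train = int(train_prop * n_samples)
    n_valid = int(valid_prop * n_samples)
    return {
        "train": perm[:n_train],
        "valid": perm[n_train:n_train + n_valid],
        "test": perm[n_train + n_valid:],  # whatever remains
    }


split = toy_random_split(20)
repeat = toy_random_split(20)  # same seed, so the same split
```

Slicing a single permutation guarantees that the three index sets are disjoint and together cover every sample.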