topobench.dataloader.utils module
Dataloader utilities.
- class Any(*args, **kwargs)
Bases: object
Special type indicating an unconstrained type.
Any is compatible with every type.
Any is assumed to have all methods.
All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
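The following minimal illustration shows the two points above and assumes Python 3.8 or later. The helper `describe` is hypothetical and not part of this module. A static checker accepts any argument for an `Any`-annotated parameter, while an instance check against `Any` raises `TypeError` at runtime:

```python
from typing import Any


def describe(value: Any) -> str:
    # Hypothetical helper: a static type checker accepts any argument here
    # because the parameter is annotated with Any; nothing is enforced at runtime.
    return f"{type(value).__name__}: {value!r}"


print(describe(42))      # int: 42
print(describe([1, 2]))  # list: [1, 2]

# Any only has meaning for static type checkers; an instance check
# against it raises TypeError at runtime.
try:
    isinstance(42, Any)
    isinstance_rejected = False
except TypeError:
    isinstance_rejected = True

print(isinstance_rejected)  # True
```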
- class DomainData(x=None, edge_index=None, edge_attr=None, y=None, pos=None, time=None, **kwargs)
Bases: Data
Helper Data class that lets sparse matrices work with PyG dataloaders even when their attribute names do not contain adj.
It overrides some methods of torch_geometric.data.Data.
- is_valid(string)
Check if the string contains any of the valid names.
- Parameters:
- string : str
String to check.
- Returns:
- bool
Whether the string contains any of the valid names.
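The pure-Python sketch below illustrates the idea behind `DomainData`. By default, PyG's `Data` concatenates a sparse attribute block-diagonally (along dimensions 0 and 1) only for certain key names, such as ones containing `adj`. A name check like `is_valid` can widen that rule. The list of valid names and the helper `cat_dim` are assumptions for illustration only. They are not the actual `DomainData` code.

```python
# Hypothetical substrings; the real list checked by DomainData.is_valid
# is not given in this documentation.
VALID_NAMES = ("adj", "incidence", "laplacian")


def is_valid(string: str, valid_names=VALID_NAMES) -> bool:
    # True if any of the valid names occurs as a substring of `string`.
    return any(name in string for name in valid_names)


def cat_dim(key: str, value_is_sparse: bool):
    # Hypothetical concatenation rule: sparse matrices whose key passes
    # is_valid are stacked block-diagonally, i.e. along dims 0 and 1.
    if value_is_sparse and is_valid(key):
        return (0, 1)
    # Index-like attributes are concatenated along the last dimension.
    if "index" in key:
        return -1
    # Everything else (e.g. feature matrices) is stacked along dim 0.
    return 0


print(is_valid("incidence_1"))       # True
print(is_valid("x_0"))               # False
print(cat_dim("incidence_1", True))  # (0, 1)
print(cat_dim("edge_index", False))  # -1
print(cat_dim("x_0", False))         # 0
```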
- class SparseTensor(row=None, rowptr=None, col=None, value=None, sparse_sizes=None, is_sorted=False, trust_data=False)
Bases: object
- __init__(row=None, rowptr=None, col=None, value=None, sparse_sizes=None, is_sorted=False, trust_data=False)
- add(other)
- add_(other)
- add_nnz(other, layout=None)
- add_nnz_(other, layout=None)
- avg_bandwidth()
- avg_col_length()
- avg_row_length()
- bandwidth()
- bandwidth_proportion(bandwidth)
- bfloat16()
- bool()
- byte()
- char()
- clear_cache_()
- clone()
- coalesce(reduce='sum')
- coo()
- copy()
- cpu()
- csc()
- csr()
- cuda(device=None, non_blocking=False)
- density()
- detach()
- detach_()
- device()
- device_as(tensor, non_blocking=False)
- dim()
- double()
- dtype()
- classmethod eye(M, N=None, has_value=True, dtype=None, device=None, fill_cache=False)
- fill_cache_()
- fill_diag(fill_value, k=0)
- fill_value(fill_value, dtype=None)
- fill_value_(fill_value, dtype=None)
- float()
- classmethod from_dense(mat, has_value=True)
- classmethod from_edge_index(edge_index, edge_attr=None, sparse_sizes=None, is_sorted=False, trust_data=False)
- from_scipy(has_value=True)
- classmethod from_storage(storage)
- classmethod from_torch_sparse_coo_tensor(mat, has_value=True)
- classmethod from_torch_sparse_csr_tensor(mat, has_value=True)
- get_diag()
- half()
- has_value()
- index_select(dim, idx)
- index_select_nnz(idx, layout=None)
- int()
- is_coalesced()
- is_cuda()
- is_floating_point()
- is_pinned()
- is_quadratic()
- is_symmetric()
- long()
- masked_select(dim, mask)
- masked_select_nnz(mask, layout=None)
- matmul(other, reduce='sum')
- max(dim=None)
- mean(dim=None)
- min(dim=None)
- mul(other)
- mul_(other)
- mul_nnz(other, layout=None)
- mul_nnz_(other, layout=None)
- narrow(dim, start, length)
- nnz()
- numel()
- partition(num_parts, recursive=False, weighted=False, node_weight=None, balance_edge=False)
- permute(perm)
- pin_memory()
- random_walk(start, walk_length)
- remove_diag(k=0)
- requires_grad()
- requires_grad_(requires_grad=True, dtype=None)
- reverse_cuthill_mckee(is_symmetric=None)
- saint_subgraph(node_idx)
- sample(num_neighbors, subset=None)
- sample_adj(subset, num_neighbors, replace=False)
- select(dim, idx)
- set_diag(values=None, k=0)
- set_value(value, layout=None)
- set_value_(value, layout=None)
- short()
- size(dim)
- sizes()
- sparse_reshape(num_rows, num_cols)
- sparse_resize(sparse_sizes)
- sparse_size(dim)
- sparse_sizes()
- sparsity()
- spmm(other, reduce='sum')
- spspmm(other, reduce='sum')
- sum(dim=None)
- t()
- to(*args, **kwargs)
- to_dense(dtype=None)
- to_device(device, non_blocking=False)
- to_scipy(layout=None, dtype=None)
- to_symmetric(reduce='sum')
- to_torch_sparse_coo_tensor(dtype=None)
- to_torch_sparse_csc_tensor(dtype=None)
- to_torch_sparse_csr_tensor(dtype=None)
- type(dtype, non_blocking=False)
- type_as(tensor, non_blocking=False)
- storage: SparseStorage
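`SparseTensor` can be built either from COO row indices (`row`) or from a CSR row pointer (`rowptr`), together with `col` and `value`. `coalesce(reduce='sum')` merges duplicate entries, and `csr()` exposes the CSR form. The pure-Python sketch below shows how these representations relate. It assumes plain lists and is a conceptual illustration, not torch_sparse's implementation. The function name `coo_to_csr` is hypothetical.

```python
from collections import defaultdict


def coo_to_csr(row, col, value, num_rows):
    """Coalesce COO entries (summing duplicates) and build CSR arrays.

    Returns (rowptr, col, value), where the entries of row i are
    col[rowptr[i]:rowptr[i + 1]] and value[rowptr[i]:rowptr[i + 1]].
    """
    # Sum values that share the same (row, col) position,
    # in the spirit of coalesce(reduce='sum').
    summed = defaultdict(float)
    for r, c, v in zip(row, col, value):
        summed[(r, c)] += v

    # Sort by (row, col) so that each row's entries are contiguous.
    entries = sorted(summed.items())

    # Count non-zeros per row, then take a running sum to get rowptr.
    rowptr = [0] * (num_rows + 1)
    for (r, _), _ in entries:
        rowptr[r + 1] += 1
    for i in range(num_rows):
        rowptr[i + 1] += rowptr[i]

    out_col = [c for (_, c), _ in entries]
    out_val = [v for _, v in entries]
    return rowptr, out_col, out_val


# A 3 x 3 matrix with a duplicate entry at position (0, 1).
rowptr, col, val = coo_to_csr(
    row=[0, 2, 0, 1],
    col=[1, 0, 1, 2],
    value=[1.0, 4.0, 2.0, 3.0],
    num_rows=3,
)
print(rowptr)  # [0, 1, 2, 3]
print(col)     # [1, 2, 0]
print(val)     # [3.0, 3.0, 4.0]
```

With torch_sparse installed, the same matrix could be built as `SparseTensor(row=..., col=..., value=..., sparse_sizes=(3, 3))`. The result of `.coalesce().csr()` should then correspond to these three lists, returned as tensors.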
- class defaultdict
Bases: dict
defaultdict(default_factory=None, /, […]) --> dict with default factory
The default factory is called without arguments to produce a new value when a key is not present, in __getitem__ only. A defaultdict compares equal to a dict with the same items. All remaining arguments are treated the same as if they were passed to the dict constructor, including keyword arguments.
- __init__(*args, **kwargs)
- copy()
Return a shallow copy of the dictionary.
- default_factory
Factory for default value called by __missing__().
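This short standard-library example shows the behaviour described above. Missing keys are filled by `default_factory` through `__getitem__`, `get()` does not insert anything, and a defaultdict compares equal to a plain `dict` with the same items:

```python
from collections import defaultdict

# Group words by their first letter; a missing key starts as an empty list.
groups = defaultdict(list)
for word in ["adj", "incidence", "laplacian", "index"]:
    groups[word[0]].append(word)

print(dict(groups))  # {'a': ['adj'], 'i': ['incidence', 'index'], 'l': ['laplacian']}

# The default factory is used by __getitem__ only: get() does not insert.
print(groups.get("z"))  # None
print("z" in groups)    # False

# Indexing a missing key calls the factory and stores the result.
print(groups["z"])      # []
print("z" in groups)    # True

# Extra arguments go to the dict constructor; equality with dict holds.
print(defaultdict(int, a=1) == {"a": 1})  # True
```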
- collate_fn(batch)
Override the torch_geometric.data.DataLoader collate function to use the DomainData class.
This ensures that torch_geometric dataloaders work with sparse matrices whose names do not necessarily contain adj. The function also generates the batch slices for the different cell dimensions.
- Parameters:
- batch : list
List of data objects (e.g., torch_geometric.data.Data).
- Returns:
- torch_geometric.data.Batch
The collated batch, including the batch slices for each cell dimension.
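The bookkeeping behind "batch slices for the different cell dimensions" can be sketched in pure Python. Each cell of dimension k gets the index of the sample it came from. The key names `x_k` and `batch_k`, the dict-of-lists inputs, and the function name `batch_slices` are assumptions for illustration. The real function works on tensors and returns a `torch_geometric.data.Batch`.

```python
from collections import defaultdict


def batch_slices(samples):
    """Concatenate per-sample cell features and build batch vectors.

    `samples` is a list of dicts mapping hypothetical keys such as "x_0",
    "x_1", "x_2" (features of 0-, 1- and 2-cells) to lists of feature rows.
    Returns the concatenated features and, for each cell dimension k, a
    "batch_k" list giving the index of the sample that owns each cell.
    """
    features = defaultdict(list)
    batches = defaultdict(list)
    for sample_idx, sample in enumerate(samples):
        for key, rows in sample.items():
            if not key.startswith("x_"):
                continue
            dim = key.split("_", 1)[1]
            features[key].extend(rows)
            # One entry per cell, each pointing back to this sample.
            batches[f"batch_{dim}"].extend([sample_idx] * len(rows))
    return dict(features), dict(batches)


samples = [
    {"x_0": [[0.1], [0.2], [0.3]], "x_1": [[1.0], [2.0]]},
    {"x_0": [[0.4], [0.5]], "x_1": [[3.0]], "x_2": [[9.0]]},
]
features, batches = batch_slices(samples)
print(batches["batch_0"])  # [0, 0, 0, 1, 1]
print(batches["batch_1"])  # [0, 0, 1]
print(batches["batch_2"])  # [1]
```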
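- to_data_list(batch)
Convert a batch back into a list of data objects. This workaround is needed because torch_geometric does not handle batches whose sparse matrices are torch.sparse tensors rather than torch_sparse tensors.
- Parameters:
- batch : torch_geometric.data.Batch
The batch to split into data objects.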
- Returns:
- list
List of data objects.
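When sparse matrices are collated block-diagonally, splitting a batch back into data objects requires shifting each entry by the row and column offsets of its sample's block. The sketch below shows one way to do that offset bookkeeping on plain COO lists. The function name `split_block_diagonal` and the `row_counts` and `col_counts` inputs are hypothetical. The actual `to_data_list` operates on a `torch_geometric.data.Batch`.

```python
def split_block_diagonal(row, col, value, row_counts, col_counts):
    """Split a block-diagonal COO matrix into one COO block per sample.

    row_counts[i] and col_counts[i] give the number of rows and columns
    contributed by sample i (hypothetical bookkeeping for this sketch).
    """
    # Starting row and column offset of each sample's block.
    row_starts, col_starts = [0], [0]
    for r, c in zip(row_counts, col_counts):
        row_starts.append(row_starts[-1] + r)
        col_starts.append(col_starts[-1] + c)

    blocks = [([], [], []) for _ in row_counts]
    for r, c, v in zip(row, col, value):
        # Find the sample whose row range contains r (linear scan for clarity).
        i = next(
            k for k in range(len(row_counts))
            if row_starts[k] <= r < row_starts[k + 1]
        )
        # Shift indices back into the sample's local coordinates.
        blocks[i][0].append(r - row_starts[i])
        blocks[i][1].append(c - col_starts[i])
        blocks[i][2].append(v)
    return blocks


# Two samples: a 2 x 3 block followed by a 1 x 2 block.
blocks = split_block_diagonal(
    row=[0, 1, 2],
    col=[2, 0, 4],
    value=[1.0, 2.0, 5.0],
    row_counts=[2, 1],
    col_counts=[3, 2],
)
print(blocks)  # [([0, 1], [2, 0], [1.0, 2.0]), ([0], [1], [5.0])]
```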