topobench.dataloader.utils module

Dataloader utilities.

class Any(*args, **kwargs)

Bases: object

Special type indicating an unconstrained type.

  • Any is compatible with every type.

  • Any is assumed to have all methods.

  • All values are assumed to be instances of Any.

Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.

class DomainData(x=None, edge_index=None, edge_attr=None, y=None, pos=None, time=None, **kwargs)

Bases: Data

Helper Data class that lets PyG dataloaders handle sparse matrices whose attribute names do not contain "adj", not only those that do.

It overrides some methods of torch_geometric.data.Data.

is_valid(string)

Check if the string contains any of the valid names.

Parameters:
string : str

String to check.

Returns:
bool

Whether the string contains any of the valid names.

class SparseTensor(row=None, rowptr=None, col=None, value=None, sparse_sizes=None, is_sorted=False, trust_data=False)

Bases: object

__init__(row=None, rowptr=None, col=None, value=None, sparse_sizes=None, is_sorted=False, trust_data=False)
add(other)
add_(other)
add_nnz(other, layout=None)
add_nnz_(other, layout=None)
avg_bandwidth()
avg_col_length()
avg_row_length()
bandwidth()
bandwidth_proportion(bandwidth)
bfloat16()
bool()
byte()
char()
clear_cache_()
clone()
coalesce(reduce='sum')
coo()
copy()
cpu()
csc()
csr()
cuda(device=None, non_blocking=False)
density()
detach()
detach_()
device()
device_as(tensor, non_blocking=False)
dim()
double()
dtype()
classmethod eye(M, N=None, has_value=True, dtype=None, device=None, fill_cache=False)
fill_cache_()
fill_diag(fill_value, k=0)
fill_value(fill_value, dtype=None)
fill_value_(fill_value, dtype=None)
float()
classmethod from_dense(mat, has_value=True)
classmethod from_edge_index(edge_index, edge_attr=None, sparse_sizes=None, is_sorted=False, trust_data=False)
from_scipy(has_value=True)
classmethod from_storage(storage)
classmethod from_torch_sparse_coo_tensor(mat, has_value=True)
classmethod from_torch_sparse_csr_tensor(mat, has_value=True)
get_diag()
half()
has_value()
index_select(dim, idx)
index_select_nnz(idx, layout=None)
int()
is_coalesced()
is_cuda()
is_floating_point()
is_pinned()
is_quadratic()
is_shared()
is_symmetric()
long()
masked_select(dim, mask)
masked_select_nnz(mask, layout=None)
matmul(other, reduce='sum')
max(dim=None)
mean(dim=None)
min(dim=None)
mul(other)
mul_(other)
mul_nnz(other, layout=None)
mul_nnz_(other, layout=None)
narrow(dim, start, length)
nnz()
numel()
partition(num_parts, recursive=False, weighted=False, node_weight=None, balance_edge=False)
permute(perm)
pin_memory()
random_walk(start, walk_length)
remove_diag(k=0)
requires_grad()
requires_grad_(requires_grad=True, dtype=None)
reverse_cuthill_mckee(is_symmetric=None)
saint_subgraph(node_idx)
sample(num_neighbors, subset=None)
sample_adj(subset, num_neighbors, replace=False)
select(dim, idx)
set_diag(values=None, k=0)
set_value(value, layout=None)
set_value_(value, layout=None)
share_memory_()
short()
size(dim)
sizes()
sparse_reshape(num_rows, num_cols)
sparse_resize(sparse_sizes)
sparse_size(dim)
sparse_sizes()
sparsity()
spmm(other, reduce='sum')
spspmm(other, reduce='sum')
sum(dim=None)
t()
to(*args, **kwargs)
to_dense(dtype=None)
to_device(device, non_blocking=False)
to_scipy(layout=None, dtype=None)
to_symmetric(reduce='sum')
to_torch_sparse_coo_tensor(dtype=None)
to_torch_sparse_csc_tensor(dtype=None)
to_torch_sparse_csr_tensor(dtype=None)
type(dtype, non_blocking=False)
type_as(tensor, non_blocking=False)
storage: SparseStorage
class defaultdict

Bases: dict

defaultdict(default_factory=None, /, [...]) --> dict with default factory

The default factory is called without arguments to produce a new value when a key is not present, in __getitem__ only. A defaultdict compares equal to a dict with the same items. All remaining arguments are treated the same as if they were passed to the dict constructor, including keyword arguments.

__init__(*args, **kwargs)#
copy() -> a shallow copy of D.
default_factory#

Factory for default value called by __missing__().

collate_fn(batch)

Override the torch_geometric.data.DataLoader collate function so that it uses the DomainData class.

This ensures that the torch_geometric dataloaders work with sparse matrices that are not necessarily named adj. The function also generates the batch slices for the different cell dimensions.

Parameters:
batch : list

List of data objects (e.g., torch_geometric.data.Data).

Returns:
torch_geometric.data.Batch

A torch_geometric.data.Batch object.

to_data_list(batch)

Workaround needed because torch_geometric does not work when torch.sparse is used instead of torch_sparse.

Parameters:
batch : torch_geometric.data.Batch

The batch of data.

Returns:
list

List of data objects.