topobench.evaluator package#

Evaluators for model evaluation.

class AbstractEvaluator#

Bases: ABC

Abstract base class for evaluators.

__init__()#
abstract compute()#

Compute the metrics.

abstract reset()#

Reset the metrics.

abstract update(model_out)#

Update the metrics with the model output.

Parameters:
model_out : dict

The model output.
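The abstract interface above (compute, reset, update) can be sketched as follows. The concrete AccuracyEvaluator subclass and its running-accuracy logic are illustrative stand-ins for this example, not TopoBench code:

```python
from abc import ABC, abstractmethod


class AbstractEvaluator(ABC):
    """Mirror of the documented abstract evaluator contract."""

    @abstractmethod
    def compute(self):
        """Compute the metrics."""

    @abstractmethod
    def reset(self):
        """Reset the metrics."""

    @abstractmethod
    def update(self, model_out):
        """Update the metrics with the model output."""


class AccuracyEvaluator(AbstractEvaluator):
    """Illustrative subclass: accumulates a running accuracy."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, model_out):
        # model_out is a dict, as in the documented interface;
        # plain lists stand in for tensors here.
        preds = model_out["preds"]
        labels = model_out["labels"]
        self.correct += sum(p == t for p, t in zip(preds, labels))
        self.total += len(labels)

    def compute(self):
        return {"accuracy": self.correct / self.total}

    def reset(self):
        self.correct = 0
        self.total = 0
```

A subclass that omits any of the three abstract methods cannot be instantiated, which is how the base class enforces the contract.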

class TBEvaluator(task, **kwargs)#

Bases: AbstractEvaluator

Evaluator class responsible for computing the metrics.

Parameters:
task : str

The task type. It can be either “classification” or “regression”.

**kwargs : dict

Additional arguments for the class; the expected arguments depend on the task.

In the “classification” scenario, the following arguments are expected:

- num_classes (int): The number of classes.
- metrics (list[str]): A list of classification metrics to be computed.

In the “regression” scenario, the following arguments are expected:

- metrics (list[str]): A list of regression metrics to be computed.
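As a concrete illustration, the keyword arguments for each task might be assembled as below; the specific metric names are illustrative assumptions, not an exhaustive list of what the library supports:

```python
# Hypothetical keyword arguments for each task, following the
# parameter description above; metric names are illustrative.
classification_kwargs = {
    "num_classes": 3,                      # required for "classification"
    "metrics": ["accuracy", "precision"],  # classification metrics
}
regression_kwargs = {
    "metrics": ["mae", "mse"],             # regression metrics
}
```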

__init__(task, **kwargs)#
compute()#

Compute the metrics.

Returns:
dict

Dictionary containing the computed metrics.

reset()#

Reset the metrics.

This method should be called after each epoch.

update(model_out)#

Update the metrics with the model output.

Parameters:
model_out : dict

The model output. It should contain the following keys:

- logits (torch.Tensor): The model predictions.
- labels (torch.Tensor): The ground truth labels.
- batch (torch_geometric.data.Data, optional): The batch data containing target normalizer stats.
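The per-epoch lifecycle described above (update after each batch, compute at the end of the epoch, then reset) can be sketched with a stand-in evaluator. The MeanAbsoluteError class and its plain-list inputs are illustrative assumptions; the real TBEvaluator consumes torch.Tensor logits and labels:

```python
class MeanAbsoluteError:
    """Toy regression metric mirroring the update/compute/reset interface."""

    def __init__(self):
        self.abs_errors, self.count = 0.0, 0

    def update(self, model_out):
        # model_out mirrors the documented keys: "logits" and "labels".
        for pred, label in zip(model_out["logits"], model_out["labels"]):
            self.abs_errors += abs(pred - label)
            self.count += 1

    def compute(self):
        return {"mae": self.abs_errors / self.count}

    def reset(self):
        self.abs_errors, self.count = 0.0, 0


evaluator = MeanAbsoluteError()
for epoch in range(2):
    for batch_out in ({"logits": [1.0, 2.0], "labels": [1.5, 2.0]},
                      {"logits": [0.0], "labels": [1.0]}):
        evaluator.update(batch_out)  # accumulate metric state per batch
    metrics = evaluator.compute()    # aggregate over the whole epoch
    evaluator.reset()                # reset after each epoch
```

Calling reset() after compute() ensures the metric state from one epoch does not leak into the next.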

Raises:
ValueError

If the task is not valid.

Subpackages#

Submodules#