topobench.evaluator package
Evaluators for model evaluation.
- class AbstractEvaluator
Bases: ABC
Abstract base class that defines the evaluator interface.
- __init__()
- abstract compute()
Compute the metrics.
- abstract reset()
Reset the metrics.
- abstract update(model_out)
Update the metrics with the model output.
- Parameters:
- model_out : dict
The model output.
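Below is a minimal sketch of a concrete subclass. To keep it self-contained, it does not import AbstractEvaluator from topobench.evaluator. Instead it defines a local stand-in with the three documented abstract methods. The MAEEvaluator class is hypothetical. So is the assumption that model_out holds plain Python lists under "logits" and "labels".

```python
from abc import ABC, abstractmethod


class AbstractEvaluator(ABC):
    """Local stand-in with the documented abstract methods (not imported from topobench)."""

    @abstractmethod
    def update(self, model_out):
        """Update the metrics with the model output."""

    @abstractmethod
    def compute(self):
        """Compute the metrics."""

    @abstractmethod
    def reset(self):
        """Reset the metrics."""


class MAEEvaluator(AbstractEvaluator):
    """Hypothetical evaluator that tracks a running mean absolute error."""

    def __init__(self):
        self.reset()

    def update(self, model_out):
        # Assumption: "logits" and "labels" are equal-length lists of floats.
        for pred, target in zip(model_out["logits"], model_out["labels"]):
            self._abs_error_sum += abs(pred - target)
            self._count += 1

    def compute(self):
        # Return NaN instead of dividing by zero when nothing has been seen yet.
        mae = self._abs_error_sum / self._count if self._count else float("nan")
        return {"mae": mae}

    def reset(self):
        self._abs_error_sum = 0.0
        self._count = 0


evaluator = MAEEvaluator()
evaluator.update({"logits": [1.0, 2.0], "labels": [1.5, 2.0]})
print(evaluator.compute())  # {'mae': 0.25}
```

Because AbstractEvaluator derives from ABC, instantiating it directly (or a subclass that omits one of the three methods) raises TypeError.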
- class TBEvaluator(task, **kwargs)
Bases: AbstractEvaluator
Evaluator class that is responsible for computing the metrics.
- Parameters:
- task : str
The task type: either “classification” or “regression”.
- **kwargs : dict
Additional task-dependent arguments for the class.
In the “classification” task, the following arguments are expected:
- num_classes (int): The number of classes.
- metrics (list[str]): The classification metrics to compute.
In the “regression” task, the following argument is expected:
- metrics (list[str]): The regression metrics to compute.
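The task-dependent keyword arguments can be illustrated with check_evaluator_kwargs. This is a hypothetical validation helper and is not part of topobench. It only encodes the expectations listed above. The metric names in the usage lines are illustrative, not a confirmed list of names TBEvaluator accepts.

```python
def check_evaluator_kwargs(task, **kwargs):
    """Hypothetical helper (not part of topobench): checks the documented per-task kwargs."""
    if task == "classification":
        required = {"num_classes": int, "metrics": list}
    elif task == "regression":
        required = {"metrics": list}
    else:
        raise ValueError(f"Invalid task: {task!r}")

    for name, expected_type in required.items():
        if name not in kwargs:
            raise ValueError(f"The {task} task requires the {name!r} argument")
        if not isinstance(kwargs[name], expected_type):
            raise TypeError(f"{name!r} must be a {expected_type.__name__}")

    # Every requested metric is expected to be given by name.
    if not all(isinstance(metric, str) for metric in kwargs["metrics"]):
        raise TypeError("'metrics' must be a list of strings")
    return kwargs


# Illustrative metric names; the accepted names depend on TBEvaluator.
check_evaluator_kwargs("classification", num_classes=3, metrics=["accuracy", "f1"])
check_evaluator_kwargs("regression", metrics=["mae", "mse"])
```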
- __init__(task, **kwargs)
- compute()
Compute the metrics.
- Returns:
- dict
Dictionary containing the computed metrics.
- reset()
Reset the metrics.
Call this method after each epoch so that metric state accumulated in one epoch does not carry over into the next.
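The following sketches the reset-after-each-epoch pattern. AccuracyEvaluator and run_epochs are hypothetical. The evaluator only exposes the same update/compute/reset methods as TBEvaluator, and plain lists stand in for torch.Tensor inputs.

```python
class AccuracyEvaluator:
    """Hypothetical evaluator exposing the same update/compute/reset methods as TBEvaluator."""

    def __init__(self):
        self.reset()

    def update(self, model_out):
        # Assumption: "logits" holds per-sample class scores, "labels" integer classes.
        for scores, label in zip(model_out["logits"], model_out["labels"]):
            pred = max(range(len(scores)), key=scores.__getitem__)
            self.correct += int(pred == label)
            self.total += 1

    def compute(self):
        return {"accuracy": self.correct / self.total if self.total else float("nan")}

    def reset(self):
        self.correct = 0
        self.total = 0


def run_epochs(evaluator, epochs):
    """Update on every batch, compute once per epoch, then reset before the next one."""
    history = []
    for batches in epochs:
        for model_out in batches:
            evaluator.update(model_out)
        history.append(evaluator.compute())
        evaluator.reset()
    return history


epoch_1 = [{"logits": [[0.9, 0.1], [0.2, 0.8]], "labels": [0, 0]}]
epoch_2 = [{"logits": [[0.3, 0.7]], "labels": [1]}]
print(run_epochs(AccuracyEvaluator(), [epoch_1, epoch_2]))
# [{'accuracy': 0.5}, {'accuracy': 1.0}]
```

Without the reset() call, the second epoch would report 2/3 instead of 1.0, because the counts from the first epoch would still be included.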
- update(model_out)
Update the metrics with the model output.
- Parameters:
- model_out : dict
The model output. It should contain the following keys:
- logits (torch.Tensor): The model predictions.
- labels (torch.Tensor): The ground-truth labels.
- batch (torch_geometric.data.Data, optional): The batch data containing the target normalizer stats.
- Raises:
- ValueError
If the task is not valid.
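The sketch below shows how update might unpack model_out and reject an unknown task, using a hypothetical helper named split_model_out. It is not TBEvaluator's actual implementation, and it differs in three ways:

- Plain lists replace torch.Tensor values.
- The classification branch takes the index of the largest score per sample as the prediction.
- The optional batch entry is passed through untouched, because the layout of its normalizer stats is not documented here.

```python
def split_model_out(model_out, task):
    """Hypothetical helper: unpacks the documented model_out keys for a given task."""
    logits = model_out["logits"]
    labels = model_out["labels"]
    batch = model_out.get("batch")  # optional; may carry target normalizer stats

    if task == "classification":
        # Predicted class = index of the highest score for each sample.
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    elif task == "regression":
        # In this sketch, regression outputs are used as-is (no denormalization).
        preds = list(logits)
    else:
        raise ValueError(f"Invalid task: {task!r}")
    return preds, list(labels), batch


preds, labels, batch = split_model_out(
    {"logits": [[0.1, 0.9], [0.7, 0.3]], "labels": [1, 0]}, "classification"
)
print(preds, labels, batch)  # [1, 0] [1, 0] None
```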
Submodules
- topobench.evaluator.base module
- topobench.evaluator.evaluator module