Source code for metrics.example

"""Example regression metric module for the topobench package."""

from typing import Any

import torch
from torchmetrics import Metric
from torchmetrics.functional.regression.mse import (
    _mean_squared_error_compute,
    _mean_squared_error_update,
)


class ExampleRegressionMetric(Metric):
    r"""Example regression metric: (root) mean squared error.

    Accumulates the sum of squared errors and the number of observations
    across calls to :meth:`update`, and reduces them in :meth:`compute`.

    Parameters
    ----------
    squared : bool
        If True (default), ``compute`` returns the mean squared error; if
        False, it returns the root mean squared error.
    num_outputs : int
        The number of outputs (default: 1). A separate sum of squared
        errors is accumulated for each output.
    **kwargs : Any
        Additional keyword arguments, forwarded to ``Metric.__init__``.

    Raises
    ------
    ValueError
        If ``squared`` is not a bool, or ``num_outputs`` is not a positive
        integer (bools are rejected).
    """

    is_differentiable = True
    # Lower error is better.
    higher_is_better = False
    full_state_update = False

    # Running sum of squared errors, one entry per output.
    sum_squared_error: torch.Tensor
    # Running count of observations seen so far.
    total: torch.Tensor

    def __init__(
        self,
        squared: bool = True,
        num_outputs: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not isinstance(squared, bool):
            raise ValueError(
                f"Expected argument `squared` to be a boolean but got {squared}"
            )
        self.squared = squared

        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # ``num_outputs=True`` is not silently treated as one output.
        if (
            isinstance(num_outputs, bool)
            or not isinstance(num_outputs, int)
            or num_outputs <= 0
        ):
            raise ValueError(
                f"Expected num_outputs to be a positive integer but got {num_outputs}"
            )
        self.num_outputs = num_outputs

        # Both states use a "sum" reduction when synchronised across
        # processes, so partial sums and counts combine correctly.
        self.add_state(
            "sum_squared_error",
            default=torch.zeros(num_outputs),
            dist_reduce_fx="sum",
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate squared errors and the observation count for a batch.

        Parameters
        ----------
        preds : torch.Tensor
            Predictions from model.
        target : torch.Tensor
            Ground truth values.
        """
        sum_squared_error, num_obs = _mean_squared_error_update(
            preds, target, num_outputs=self.num_outputs
        )
        self.sum_squared_error += sum_squared_error
        self.total += num_obs

    def compute(self) -> torch.Tensor:
        """Compute the error over everything accumulated by ``update``.

        Returns
        -------
        torch.Tensor
            Mean squared error when ``squared`` is True, otherwise root mean
            squared error.
        """
        return _mean_squared_error_compute(
            self.sum_squared_error, self.total, squared=self.squared
        )