topobench.evaluator.metrics.example module#
Example metric module for the topobench package.
- class Any(*args, **kwargs)#
Bases: object
Special type indicating an unconstrained type.
- Any is compatible with every type.
- Any is assumed to have all methods.
- All values are assumed to be instances of Any.
Note that all the above statements are true from the point of view of static type checkers. At runtime, Any should not be used with instance checks.
- class ExampleRegressionMetric(squared=True, num_outputs=1, **kwargs)#
Bases: Metric
Example metric.
- Parameters:
- squared : bool
Whether to compute the squared error (default: True).
- num_outputs : int
The number of outputs.
- **kwargs : Any
Additional keyword arguments.
- __init__(squared=True, num_outputs=1, **kwargs)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- compute()#
Compute mean squared error over state.
- Returns:
- torch.Tensor
Mean squared error.
- update(preds, target)#
Update state with predictions and targets.
- Parameters:
- preds : torch.Tensor
Predictions from model.
- target : torch.Tensor
Ground truth values.
- sum_squared_error: Tensor#
- total: Tensor#
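The pieces above fit together in the usual torchmetrics pattern: states are registered in __init__, accumulated in update(), and aggregated in compute(). The following is a minimal sketch of such a metric written against the documented interface; the class name and exact implementation are illustrative, not the package's actual source.

import torch
from torchmetrics import Metric


class SketchRegressionMetric(Metric):
    """Illustrative MSE-style metric mirroring ExampleRegressionMetric."""

    def __init__(self, squared=True, num_outputs=1, **kwargs):
        super().__init__(**kwargs)
        self.squared = squared
        # States added via add_state are reset, moved between devices,
        # and synchronized across processes automatically.
        self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # Accumulate this batch's statistics into the metric state.
        self.sum_squared_error += torch.sum((preds - target) ** 2, dim=0)
        self.total += target.shape[0]

    def compute(self):
        mse = self.sum_squared_error / self.total
        return mse if self.squared else torch.sqrt(mse)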
- class Metric(**kwargs)#
Bases: Module, ABC
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
- Handles the transfer of metric states to the correct device.
- Handles the synchronization of metric states across processes.
- Provides properties and methods to control the overall behavior of the metric and its states.
The three core methods of the base class are:
add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().
- Parameters:
kwargs (Any) – additional keyword arguments, see Metric kwargs for more info.
- compute_on_cpu:
If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step:
If metric state should synchronize on forward(). Default is False.
- process_group:
The process group on which the synchronization is called. Default is the world.
- dist_sync_fn:
Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn:
Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute:
If metric state should synchronize when compute is called. Default is True.
- compute_with_cache:
If results from compute should be cached. Default is True.
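These options are plain keyword arguments to any metric's constructor. A short example with a stock torchmetrics metric (the specific metric choice is illustrative):

>>> import torch
>>> from torchmetrics import MeanSquaredError
>>> metric = MeanSquaredError(dist_sync_on_step=True, compute_with_cache=True)
>>> metric.update(torch.tensor([2.5, 0.0]), torch.tensor([3.0, 0.5]))
>>> metric.compute()
tensor(0.2500)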
- __init__(**kwargs)#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- add_state(name, default, dist_reduce_fx=None, persistent=False)#
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.
- Parameters:
name (str) – The name of the state variable. The variable will then be accessible at self.name.
default (list | Tensor) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won't be any reduction function applied to the synchronized metric state. The metric states would be synced as follows:
- If the metric state is a Tensor, the synced value will be a stacked Tensor across the process dimension. The original Tensor metric state retains its dimensions, so the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Important
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
Caution
The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after reset() is called, you must first copy them to another object.
- Raises:
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
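To make the two kinds of state concrete, here is a short sketch of a hypothetical subclass using both a Tensor state and a list state; dim_zero_cat is the torchmetrics helper that handles a list state both before and after synchronization.

import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat


class StateDemo(Metric):
    """Illustrative metric with one Tensor state and one list state."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Tensor state: reduced with torch.sum across processes.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        # List state: appended to per batch, concatenated with torch.cat on sync.
        self.add_state("values", default=[], dist_reduce_fx="cat")

    def update(self, x):
        self.count += x.numel()
        self.values.append(x.flatten())

    def compute(self):
        # dim_zero_cat works for both the local list and the synced tensor.
        return dim_zero_cat(self.values).sum() / self.count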
- clone()#
Make a copy of the metric.
- abstract compute()#
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in distributed backend.
- double()#
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- float()#
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- forward(*args, **kwargs)#
Aggregate and evaluate batch input directly.
Serves the dual purpose of computing the metric on the current batch of inputs while also adding the batch statistics to the overall accumulating metric state. Input arguments are the exact same as those of the corresponding update method. The returned output is the exact same as the output of compute.
- Returns:
The output of the compute method evaluated on the current batch.
- Raises:
TorchMetricsUserError – If the metric is already synced and forward is called again.
- half()#
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- merge_state(incoming_state)#
Merge incoming metric state to the current state of the metric.
- Parameters:
incoming_state (dict[str, Any] | Metric) – either a dict containing a metric state similar to the metric itself or an instance of the metric class.
- Raises:
ValueError – If the incoming state is neither a dict nor an instance of the metric class.
RuntimeError – If the metric has full_state_update=True or dist_sync_on_step=True. In these cases, the metric cannot be merged with another metric state in a simple way. The user should overwrite the method in the metric class to handle the merge operation.
ValueError – If the incoming state is a metric instance but the class is different from the current metric class.
Example with a metric instance:
>>> from torchmetrics.aggregation import SumMetric
>>> metric1 = SumMetric()
>>> metric2 = SumMetric()
>>> metric1.update(1)
>>> metric2.update(2)
>>> metric1.merge_state(metric2)
>>> metric1.compute()
tensor(3.)
Example with a dict:
>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> # SumMetric has one state variable called `sum_value`
>>> metric.merge_state({"sum_value": torch.tensor(2)})
>>> metric.compute()
tensor(3.)
- persistent(mode=False)#
Change post-init whether metric states should be saved in the module's state_dict.
- plot(*_, **__)#
Override this method to plot the metric value.
- reset()#
Reset metric state variables to their default value.
- set_dtype(dst_type)#
Transfer all metric state to a specific dtype. Special version of the standard type method.
- Parameters:
dst_type (str | dtype) – the desired type as string or dtype object
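Because float(), double() and half() are disabled, dtype changes go through this method. A small illustrative example (the metric choice is arbitrary):

>>> import torch
>>> from torchmetrics import MeanSquaredError
>>> metric = MeanSquaredError()
>>> _ = metric.set_dtype(torch.float64)  # cast all metric states to float64
>>> metric.dtype
torch.float64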
- state_dict(destination=None, prefix='', keep_vars=False)#
Get the current state of the metric as a dictionary.
- Parameters:
destination (dict[str, Any] | None) – Optional dictionary; if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.
prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in the state_dict.
keep_vars (bool) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)#
Sync function for manually controlling when metrics states should be synced across processes.
- Parameters:
dist_sync_fn (Callable | None) – Function to be used to perform state synchronization.
process_group (Any | None) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This only has an impact when running in a distributed setting.
distributed_available (Callable | None) – Function to determine if we are running inside a distributed setting.
- Raises:
TorchMetricsUserError – If the metric is already synced and sync is called again.
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)#
Context manager to synchronize states.
This context manager is used in distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.
- Parameters:
dist_sync_fn (Callable | None) – Function to be used to perform state synchronization.
process_group (Any | None) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This only has an impact when running in a distributed setting.
should_unsync (bool) – Whether to restore the cached state so that the metrics can continue to be accumulated.
distributed_available (Callable | None) – Function to determine if we are running inside a distributed setting.
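A hedged sketch of manual synchronization control; in a single-process run the synchronization step is simply skipped, so the snippet is safe to run outside torch.distributed (the metric choice is illustrative):

import torch
from torchmetrics import MeanSquaredError

metric = MeanSquaredError()
metric.update(torch.tensor([1.0]), torch.tensor([0.0]))

# Temporarily swap in the synchronized (all-gathered) states,
# then restore the local cache on exit.
with metric.sync_context(should_sync=True, should_unsync=True):
    value = metric.compute()

# Equivalent manual control with sync()/unsync():
metric.sync()
value = metric.compute()
metric.unsync()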
- type(dst_type)#
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- unsync(should_unsync=True)#
Unsync function for manually controlling when metric states should be reverted back to their local states.
- Parameters:
should_unsync (bool) – Whether to perform unsync
- abstract update(*_, **__)#
Override this method to update the state variables of your metric class.
- property device: device#
Return the device of the metric.
- property dtype: dtype#
Return the default dtype of the metric.