Module metrics
Base class
The base Metric class is an abstract base class that is used as the building block for all other Module metrics.
class torchmetrics.Metric(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Base class for all metrics present in the Metrics API.
Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.
Override update() and compute() to implement your own metric. Use add_state() to register metric state variables, which keep track of state on each call of update() and are synchronized across processes when compute() is called.
Note
Metric state variables can be either torch.Tensor objects or an empty list, which can be used to store torch.Tensor objects.
Note
Different metrics only override update() and not forward(). A call to update() is valid, but it won't return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.
Parameters
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
- dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
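The update/compute/forward split described above can be illustrated without PyTorch. The following is a minimal, dependency-free sketch built around a hypothetical `ToyMeanMetric` class, which is not part of torchmetrics. Its `add_state` only records defaults. The class mirrors the documented behaviour: `update()` accumulates state silently, `forward()` also returns the value for the current batch, and `reset()` restores the defaults. Distributed synchronization is deliberately left out.

```python
import copy


class ToyMeanMetric:
    """Hypothetical stand-in for a Metric subclass (no distributed synchronization)."""

    def __init__(self):
        self._defaults = {}
        # Register state the way add_state() is described: a name plus a default value.
        self.add_state("total", 0.0)
        self.add_state("count", 0)

    def add_state(self, name, default):
        # Keep the default so reset() can restore it, and expose the state as self.<name>.
        self._defaults[name] = default
        setattr(self, name, copy.deepcopy(default))

    def update(self, values):
        # Accumulate state only; nothing is returned.
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        # Final value derived from the accumulated state.
        return self.total / self.count

    def reset(self):
        # Restore every registered state to its default.
        for name, default in self._defaults.items():
            setattr(self, name, copy.deepcopy(default))

    def forward(self, values):
        # Accumulate into the global state, then report the value for this batch alone.
        self.update(values)
        batch_metric = ToyMeanMetric()
        batch_metric.update(values)
        return batch_metric.compute()


metric = ToyMeanMetric()
print(metric.forward([1.0, 2.0, 3.0]))  # 2.0, value for the first batch
print(metric.forward([10.0]))           # 10.0, value for the second batch
print(metric.compute())                 # 4.0, mean over all four values seen so far
metric.reset()
```

Computing the batch value on a fresh object keeps the per-step result independent of earlier batches, which is the distinction the note draws between update() and forward().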
add_state(name, default, dist_reduce_fx=None, persistent=False)
Adds a metric state variable. Only used by subclasses.
Parameters
- name (str) – The name of the state variable. The variable will then be accessible at self.name.
- default – Default value of the state; can be either a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.
- dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", or "cat", torch.sum, torch.mean, and torch.cat are used respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
- persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, no reduction function will be applied to the synchronized metric state. The metric states are synced as follows:
- If the metric state is a torch.Tensor, the synced value will be a torch.Tensor stacked across the process dimension. The original torch.Tensor metric state retains its dimensions, hence the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
Raises
- ValueError – If default is not a tensor or an empty list.
- ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", None.
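To make the reduction options concrete, here is a small pure-Python simulation of what each dist_reduce_fx choice does with the per-process states once they have been gathered. It is a sketch under simplifying assumptions. Each process's state is modelled as a Python list of numbers, standing in for a 1D tensor reduced along dim=0. The helper `reduce_gathered` is hypothetical and is not a torchmetrics function.

```python
def reduce_gathered(per_process_states, dist_reduce_fx):
    """Combine states gathered from all processes (one list per process)."""
    if dist_reduce_fx is None:
        # No reduction: the states are only stacked along a new process dimension.
        return [list(state) for state in per_process_states]
    if dist_reduce_fx == "sum":
        # Elementwise sum over the process dimension.
        return [sum(values) for values in zip(*per_process_states)]
    if dist_reduce_fx == "mean":
        n = len(per_process_states)
        return [sum(values) / n for values in zip(*per_process_states)]
    if dist_reduce_fx == "cat":
        # Concatenation of all elements, as for list states.
        return [x for state in per_process_states for x in state]
    if callable(dist_reduce_fx):
        # A custom function receives the stacked, unreduced states.
        return dist_reduce_fx([list(state) for state in per_process_states])
    raise ValueError("dist_reduce_fx must be callable or one of 'sum', 'mean', 'cat', None")


gathered = [[1.0, 2.0], [3.0, 4.0]]  # two processes, each holding a state of length 2
print(reduce_gathered(gathered, "sum"))   # [4.0, 6.0]
print(reduce_gathered(gathered, "mean"))  # [2.0, 3.0]
print(reduce_gathered(gathered, "cat"))   # [1.0, 2.0, 3.0, 4.0]
print(reduce_gathered(gathered, None))    # [[1.0, 2.0], [3.0, 4.0]]
print(reduce_gathered(gathered, lambda stacked: [max(col) for col in zip(*stacked)]))  # [3.0, 4.0]
```

The last call shows why the second note matters: a custom function sees the same stacked layout that None returns.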
clone()
Make a copy of the metric.
abstract compute()
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
forward(*args, **kwargs)
Automatically calls update(). Returns the metric value over inputs if compute_on_step is True.
persistent(mode=False)
Post-init method to change whether metric states should be saved to the metric's state_dict.
reset()
This method automatically resets the metric state variables to their default value.
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.
Returns
a dictionary containing a whole state of the module
Return type
dict
Example:
>>> module.state_dict().keys()
['bias', 'weight']
We also have an AverageMeter class that is helpful for defining ad-hoc metrics when creating your own metric type might be too burdensome.
class torchmetrics.AverageMeter(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes the average of a stream of values.
Forward accepts
- value (float tensor): (...)
- weight (float tensor): (...)
Parameters
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
- dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Example:
>>> import torch
>>> from torchmetrics import AverageMeter
>>> avg = AverageMeter()
>>> avg.update(3)
>>> avg.update(1)
>>> avg.compute()
tensor(2.)
>>> avg = AverageMeter()
>>> values = torch.tensor([1., 2., 3.])
>>> avg(values)
tensor(2.)
>>> avg = AverageMeter()
>>> values = torch.tensor([1., 2.])
>>> weights = torch.tensor([3., 1.])
>>> avg(values, weights)
tensor(1.2500)
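The examples above are consistent with a weighted running mean, sum(w * v) / sum(w), where each weight defaults to 1. For example, (3 * 1 + 1 * 2) / (3 + 1) = 1.25. The pure-Python sketch below reproduces those numbers. `StreamingWeightedMean` is a hypothetical name. Treating a missing weight as 1 is an assumption inferred from the examples rather than taken from the torchmetrics source.

```python
class StreamingWeightedMean:
    """Hypothetical pure-Python analogue of a streaming (weighted) average."""

    def __init__(self):
        self.weighted_sum = 0.0
        self.total_weight = 0.0

    def update(self, values, weights=None):
        # Accept a single number or a sequence; missing weights count as 1 (assumed).
        if isinstance(values, (int, float)):
            values = [values]
        if weights is None:
            weights = [1.0] * len(values)
        for v, w in zip(values, weights):
            self.weighted_sum += w * v
            self.total_weight += w

    def compute(self):
        return self.weighted_sum / self.total_weight


avg = StreamingWeightedMean()
avg.update(3)
avg.update(1)
print(avg.compute())  # 2.0

avg = StreamingWeightedMean()
avg.update([1.0, 2.0], [3.0, 1.0])
print(avg.compute())  # 1.25
```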
compute()
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
Return type
Tensor
Classification Metrics
Input types
For the purposes of classification metrics, inputs (predictions and targets) are split into these categories (N stands for the batch size and C for number of classes):
Type | preds shape | preds dtype | target shape | target dtype |
---|---|---|---|---|
Binary | (N,) | float | (N,) | binary |
Multi-class | (N,) | int | (N,) | int |
Multi-class with probabilities | (N, C) | float | (N,) | int |
Multi-label | (N, …) | float | (N, …) | binary |
Multi-dimensional multi-class | (N, …) | int | (N, …) | int |
Multi-dimensional multi-class with probabilities | (N, C, …) | float | (N, …) | int |
Here, binary means integers that are either 0 or 1.
Note
All dimensions of size 1 (except N) are “squeezed out” at the beginning, so that, for example, a tensor of shape (N, 1) is treated as (N,).
When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types:
import torch
# Binary inputs
binary_preds = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 1])
# Multi-class inputs
mc_preds = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])
# Multi-class inputs with probabilities
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])
# Multi-label inputs
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
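Probability inputs are ultimately compared with integer labels. Binary and multi-label probabilities are thresholded, while multi-class probabilities are reduced to the most probable class. The sketch below shows both conversions in plain Python on the example values above, written as nested lists. The helper names are hypothetical, and the 0.5 threshold mirrors the default threshold of metrics such as Accuracy. Counting a value exactly at the threshold as positive (>=) is an assumption of this sketch.

```python
def threshold_labels(probs, threshold=0.5):
    """Binary / multi-label: each probability becomes 1 if it reaches the threshold."""
    if isinstance(probs[0], list):
        return [threshold_labels(row, threshold) for row in probs]
    return [1 if p >= threshold else 0 for p in probs]


def argmax_labels(probs):
    """Multi-class probabilities of shape (N, C): pick the most probable class per sample."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


binary_preds = [0.6, 0.1, 0.9]
mc_preds_probs = [[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]
ml_preds = [[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]

print(threshold_labels(binary_preds))  # [1, 0, 1]
print(argmax_labels(mc_preds_probs))   # [0, 2, 1]
print(threshold_labels(ml_preds))      # [[0, 1, 1], [1, 1, 0], [0, 0, 0]]
```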
Using the multiclass parameter
In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label, for example, if both predictions and targets are integer (binary) tensors. Or it could be the other way around: you want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.
For these cases, the metrics where this distinction would make a difference expose the multiclass argument. Let's see how this is used on the example of the StatScores metric.
First, let's consider the case with label predictions with 2 classes, which we want to treat as binary.
import torch
from torchmetrics.functional import stat_scores
# These inputs are supposed to be binary, but appear as multi-class
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])
As you can see below, by default the inputs are treated as multi-class. We can set multiclass=False to treat the inputs as binary, which is the same as converting the predictions to float beforehand.
>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
Next, consider the opposite example: inputs are binary (as predictions are probabilities), but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.
preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])
In this case we can set multiclass=True to treat the inputs as multi-class.
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
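Each row returned by stat_scores above follows the layout [tp, fp, tn, fn, support], where support = tp + fn. The pure-Python sketch below recomputes these rows from label predictions, treating each class one-vs-rest, and reproduces the num_classes=2 output. `stat_scores_per_class` is a hypothetical helper, not the torchmetrics implementation.

```python
def stat_scores_per_class(preds, target, num_classes):
    """Return one [tp, fp, tn, fn, support] row per class, treating each class one-vs-rest."""
    rows = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        tn = sum(1 for p, t in zip(preds, target) if p != c and t != c)
        fn = sum(1 for p, t in zip(preds, target) if p != c and t == c)
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


rows = stat_scores_per_class([0, 1, 0], [1, 1, 0], num_classes=2)
print(rows)     # [[1, 1, 1, 0, 1], [1, 0, 1, 1, 2]]
print(rows[1])  # [1, 0, 1, 1, 2], the single row returned when the inputs are treated as binary
```

Treating the inputs as binary keeps only the statistics of the positive class 1, which is why the multiclass=False output equals the second row.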
Accuracy
class torchmetrics.Accuracy(threshold=0.5, num_classes=None, average='micro', mdmc_average='global', ignore_index=None, top_k=None, multiclass=None, subset_accuracy=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes Accuracy:
$$\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)$$
where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.
For multi-class and multi-dimensional multi-class data with probability predictions, the parameter top_k generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability items are considered to find the correct label.
For multi-label and multi-dimensional multi-class inputs, this metric computes the “global” accuracy by default, which counts all labels or sub-samples separately. This can be changed to subset accuracy (which requires all labels or sub-samples in the sample to be correctly predicted) by setting subset_accuracy=True.
Accepts all input types listed in Input types.
Parameters
- num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
- threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
- average (str) – Defines the reduction that is applied. Should be one of the following:
  'micro' [default]: Calculate the metric globally, across all samples and classes.
  'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
  'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
  'none' or None: Calculate the metric for each class separately, and return the metric for every class.
  'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
  Note: What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
- mdmc_average (Optional[str]) – Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
  None: For data that is not multi-dimensional multi-class.
  'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.
  'global' [default, per the signature above]: In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
- ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
- top_k (Optional[int]) – Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The default value (None) will be interpreted as 1 for these inputs. Should be left at default (None) for all other types of inputs.
- multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's documentation section (Using the multiclass parameter) for a more detailed explanation and examples.
- subset_accuracy (bool) – Whether to compute subset accuracy for multi-label and multi-dimensional multi-class inputs (has no effect for other input types).
  For multi-label inputs, if the parameter is set to True, then all labels for each sample must be correctly predicted for the sample to count as correct. If it is set to False, then all labels are counted separately; this is equivalent to flattening inputs beforehand (i.e. preds = preds.flatten() and same for target).
  For multi-dimensional multi-class inputs, if the parameter is set to True, then all sub-samples (on the extra axis) must be correct for the sample to be counted as correct. If it is set to False, then all sub-samples are counted separately; this is equivalent, in the case of label predictions, to flattening the inputs beforehand (i.e. preds = preds.flatten() and same for target). Note that the top_k parameter still applies in both cases, if set.
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False.
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
- dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Raises
- ValueError – If threshold is not between 0 and 1.
- ValueError – If top_k is not an integer larger than 0.
- ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.
- ValueError – If two different input modes are provided, e.g. using multi-label with multi-class.
- ValueError – If the top_k parameter is set for multi-label inputs.
Example
>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
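The accuracy formula, its top-K generalization and the global-versus-subset distinction for multi-label inputs can be checked by hand with the pure-Python sketch below. The function names are hypothetical, and this is not the torchmetrics implementation. The multi-label labels are the Input types ml_preds thresholded at 0.5, compared against ml_target from the same section.

```python
def accuracy(preds, target):
    """Fraction of samples whose predicted label equals the target label."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def top_k_accuracy(probs, target, k):
    """A sample counts as correct if its target is among the k highest-probability classes."""
    hits = 0
    for row, t in zip(probs, target):
        top_k = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        hits += int(t in top_k)
    return hits / len(target)


def multilabel_accuracy(preds, target, subset_accuracy=False):
    """Global accuracy counts every label; subset accuracy needs a whole row to be correct."""
    if subset_accuracy:
        return sum(p == t for p, t in zip(preds, target)) / len(target)
    pairs = [(p, t) for prow, trow in zip(preds, target) for p, t in zip(prow, trow)]
    return sum(p == t for p, t in pairs) / len(pairs)


print(accuracy([0, 2, 1, 3], [0, 1, 2, 3]))  # 0.5
probs = [[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]
print(round(top_k_accuracy(probs, [0, 1, 2], k=2), 4))  # 0.6667

# Input-types ml_preds thresholded at 0.5, and the matching ml_target.
ml_labels = [[0, 1, 1], [1, 1, 0], [0, 0, 0]]
ml_target = [[0, 1, 1], [1, 0, 0], [0, 0, 0]]
print(round(multilabel_accuracy(ml_labels, ml_target), 4))                        # 0.8889
print(round(multilabel_accuracy(ml_labels, ml_target, subset_accuracy=True), 4))  # 0.6667
```

With subset_accuracy=False, the single wrong label costs 1/9. With subset_accuracy=True, it invalidates its whole sample and costs 1/3.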
compute()
Computes accuracy based on inputs previously passed to update().
Return type
Tensor
AveragePrecision
class torchmetrics.AveragePrecision(num_classes=None, pos_label=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) with integer labels
Parameters
- num_classes (Optional[int]) – integer with number of classes. Not necessary to provide for binary problems.
- pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as we iteratively change it in the range [0, num_classes-1].
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. Default: False
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
Example (binary case):
>>> import torch
>>> from torchmetrics import AveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(pos_label=1)
>>> average_precision(pred, target)
tensor(1.)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=5)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
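Average precision is commonly defined as the step-wise sum AP = sum over n of (R_n - R_{n-1}) * P_n, taken over decreasing score thresholds. This sketch assumes that definition; it has not been checked line by line against the torchmetrics source. The pure-Python code below reproduces both examples above. Tied scores are grouped into a single threshold, which is what yields 0.25 for classes 2 and 3. A class with no positive samples gives nan. `average_precision` here is a hypothetical helper.

```python
def average_precision(scores, labels, pos_label=1):
    """Step-wise AP: sum of (recall gain) * (precision) over distinct score thresholds."""
    positives = sum(1 for y in labels if y == pos_label)
    if positives == 0:
        return float("nan")  # undefined without positive samples
    ap, prev_recall = 0.0, 0.0
    # Visit distinct thresholds from the highest score to the lowest so that ties
    # form a single cut-off.
    for threshold in sorted(set(scores), reverse=True):
        predicted = [y for s, y in zip(scores, labels) if s >= threshold]
        tp = sum(1 for y in predicted if y == pos_label)
        precision = tp / len(predicted)
        recall = tp / positives
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


print(round(average_precision([0, 1, 2, 3], [0, 1, 1, 1]), 4))  # 1.0

pred = [[0.75, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.75, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.75, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]
# One-vs-rest: class c is positive wherever target == c.
per_class = [average_precision([row[c] for row in pred], [int(t == c) for t in target])
             for c in range(5)]
print([round(v, 4) for v in per_class])  # [1.0, 1.0, 0.25, 0.25, nan]
```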
compute()
Compute the average precision score.
AUC
class torchmetrics.AUC(reorder=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes Area Under the Curve (AUC) using the trapezoidal rule.
Forward accepts two input tensors that should be 1D and have the same number of elements.
Parameters
- reorder (bool) – AUC expects its first input to be sorted. If this is not the case, setting this argument to True will use a stable sorting algorithm to sort the input in descending order.
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False.
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
- dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
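The trapezoidal rule approximates the area by summing, for each pair of neighbouring points, the interval width times the mean of the two heights. The pure-Python sketch below implements that rule. It is a hypothetical helper, and it assumes x is already sorted in ascending order, so it does not model the reorder option.

```python
def trapezoid_auc(x, y):
    """Area under the piecewise-linear curve through (x[i], y[i]); x must be sorted ascending."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of elements")
    area = 0.0
    for i in range(1, len(x)):
        # One trapezoid per interval: width times the mean of the two heights.
        area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2
    return area


print(trapezoid_auc([0, 1, 2, 3], [0, 1, 2, 2]))        # 4.0
print(trapezoid_auc([0.0, 0.5, 1.0], [0.0, 1.0, 1.0]))  # 0.75
```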
AUROC
class torchmetrics.AUROC(num_classes=None, pos_label=None, average='macro', max_fpr=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC). Works for binary, multilabel and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) or (N, C, ...) with integer labels
For non-binary input, if the preds and target tensors have the same size, the input will be interpreted as multilabel, and if preds has one dimension more than the target tensor, the input will be interpreted as multiclass.
Parameters
- num_classes (Optional[int]) – integer with number of classes. Not necessary to provide for binary problems.
- pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as we iteratively change it in the range [0, num_classes-1].
- average (Optional[str]) – Defines the averaging that is applied:
  'micro' computes the metric globally. Only works for multilabel problems.
  'macro' computes the metric for each class and uniformly averages them.
  'weighted' computes the metric for each class and does a weighted average, where each class is weighted by its support (accounts for class imbalance).
  None computes and returns the metric per class.
- max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr]. Should be a float between 0 and 1.
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
- dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Raises
- ValueError – If average is none of None, "macro" or "weighted".
- ValueError – If max_fpr is not a float in the range (0, 1].
- RuntimeError – If the PyTorch version is below 1.6, since max_fpr requires torch.bucketize, which is not available below 1.6.
- ValueError – If the mode of data (binary, multi-label, multi-class) changes between batches.
Example (binary case):
>>> import torch
>>> from torchmetrics import AUROC
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(pos_label=1)
>>> auroc(preds, target)
tensor(0.5000)
Example (multiclass case):
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)
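For a binary problem, ROC AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half. This is the Mann-Whitney formulation. The pure-Python sketch below uses that pairwise equivalence rather than integrating the ROC curve, which is what a library would typically do. It reproduces both examples above: the binary one directly, and the multiclass one as the unweighted mean of the one-vs-rest class scores (the 'macro' default). `pairwise_auroc` is a hypothetical helper.

```python
def pairwise_auroc(scores, labels, pos_label=1):
    """Binary ROC AUC via pairwise comparison of positive and negative scores."""
    pos = [s for s, y in zip(scores, labels) if y == pos_label]
    neg = [s for s, y in zip(scores, labels) if y != pos_label]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ordering
    return wins / (len(pos) * len(neg))


print(pairwise_auroc([0.13, 0.26, 0.08, 0.19, 0.34], [0, 0, 1, 1, 1]))  # 0.5

preds = [[0.90, 0.05, 0.05], [0.05, 0.90, 0.05], [0.05, 0.05, 0.90],
         [0.85, 0.05, 0.10], [0.10, 0.10, 0.80]]
target = [0, 1, 1, 2, 2]
# One-vs-rest per class, then an unweighted ('macro') average.
per_class = [pairwise_auroc([row[c] for row in preds], [int(t == c) for t in target])
             for c in range(3)]
print(round(sum(per_class) / 3, 4))  # 0.7778
```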
BinnedAveragePrecision
class torchmetrics.BinnedAveragePrecision(num_classes, num_thresholds=100, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Computation is performed in constant memory by computing precision and recall for num_thresholds buckets/thresholds (evenly distributed between 0 and 1).
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) with integer labels
Parameters
- num_classes (int) – integer with number of classes. For binary problems, set it to 1.
- num_thresholds (int) – number of bins used for computation. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory. Default: 100
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = BinnedAveragePrecision(num_classes=1, num_thresholds=10)
>>> average_precision(pred, target)
tensor(1.0000)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = BinnedAveragePrecision(num_classes=5, num_thresholds=10)
>>> average_precision(pred, target)
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
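Memory stays constant because only a fixed set of counters is kept per threshold, however many samples arrive. The pure-Python sketch below illustrates that idea with a hypothetical `BinnedCounter` class. It keeps one TP/FP/FN counter per threshold and updates them batch by batch. It places thresholds at i / (num_thresholds - 1), which matches the thresholds printed in the BinnedPrecisionRecallCurve example. It then applies the step-wise AP sum over the binned precision/recall points. Two conventions are assumptions of this sketch rather than facts about the library: precision is 1 when nothing reaches a threshold, and a closing (precision=1, recall=0) point is appended.

```python
class BinnedCounter:
    """Hypothetical constant-memory accumulator: one TP/FP/FN counter per threshold."""

    def __init__(self, num_thresholds=100):
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        self.tp = [0] * num_thresholds
        self.fp = [0] * num_thresholds
        self.fn = [0] * num_thresholds

    def update(self, scores, labels):
        # Each batch only increments existing counters; memory does not grow.
        for i, t in enumerate(self.thresholds):
            for s, y in zip(scores, labels):
                if s >= t:
                    if y == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1
                elif y == 1:
                    self.fn[i] += 1

    def average_precision(self):
        # Assumed conventions: precision is 1 when nothing is predicted positive,
        # and a closing (precision=1, recall=0) point ends the curve.
        precision = [tp / (tp + fp) if tp + fp else 1.0 for tp, fp in zip(self.tp, self.fp)] + [1.0]
        recall = [tp / (tp + fn) if tp + fn else 0.0 for tp, fn in zip(self.tp, self.fn)] + [0.0]
        # Recall decreases as the threshold grows, hence the minus sign.
        return -sum((recall[i + 1] - recall[i]) * precision[i] for i in range(len(recall) - 1))


counter = BinnedCounter(num_thresholds=10)
counter.update([0, 1], [0, 1])  # first batch
counter.update([2, 3], [1, 1])  # second batch
print(counter.average_precision())  # 1.0, as in the binary example above
```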
BinnedPrecisionRecallCurve
class torchmetrics.BinnedPrecisionRecallCurve(num_classes, num_thresholds=100, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Computation is performed in constant memory by computing precision and recall for num_thresholds buckets/thresholds (evenly distributed between 0 and 1).
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) or (N, C, ...) with integer labels
Parameters
- num_classes (int) – integer with number of classes. For binary, set to 1.
- num_thresholds (int) – number of bins used for computation. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory. Default: 100
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. Default: False
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=3)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([0.2500, 1.0000, 1.0000, 1.0000]),
 tensor([0.2500, 1.0000, 1.0000, 1.0000]),
 tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
 tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
 tensor([2.5000e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00])]
>>> recall
[tensor([1.0000, 1.0000, 0.0000, 0.0000]),
 tensor([1.0000, 1.0000, 0.0000, 0.0000]),
 tensor([1.0000, 0.0000, 0.0000, 0.0000]),
 tensor([1.0000, 0.0000, 0.0000, 0.0000]),
 tensor([0., 0., 0., 0.])]
>>> thresholds
[tensor([0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 1.0000])]
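The binary example can be recomputed by hand. The thresholds are i / (num_thresholds - 1). Precision and recall are evaluated for the samples whose score reaches each threshold, and one closing (precision=1, recall=0) point is appended. The pure-Python sketch below reproduces the binary output exactly. `binned_pr_curve` is hypothetical, and it takes precision as 1 when no sample reaches a threshold. The 1.0000e-06 entries in the multiclass output suggest the library instead adds a small constant where this sketch would give exactly 0, so edge-case values may differ.

```python
def binned_pr_curve(scores, labels, num_thresholds):
    """Precision/recall at evenly spaced thresholds in [0, 1] for one binary task."""
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    positives = sum(labels)
    precision, recall = [], []
    for t in thresholds:
        predicted = [y for s, y in zip(scores, labels) if s >= t]
        tp = sum(predicted)
        # Assumed convention: precision is 1 when no sample reaches the threshold.
        precision.append(tp / len(predicted) if predicted else 1.0)
        recall.append(tp / positives)
    # Closing point (precision=1, recall=0), as in the example output above.
    return precision + [1.0], recall + [0.0], thresholds


precision, recall, thresholds = binned_pr_curve([0, 0.1, 0.8, 0.4], [0, 1, 1, 0], num_thresholds=5)
print(precision)   # [0.5, 0.5, 1.0, 1.0, 1.0, 1.0]
print(recall)      # [1.0, 0.5, 0.5, 0.5, 0.0, 0.0]
print(thresholds)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```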
BinnedRecallAtFixedPrecision
class torchmetrics.BinnedRecallAtFixedPrecision(num_classes, min_precision, num_thresholds=100, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes the highest possible recall value given the minimum precision threshold provided.
Computation is performed in constant memory by computing precision and recall for num_thresholds buckets/thresholds (evenly distributed between 0 and 1).
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) with integer labels
Parameters
- num_classes (int) – integer with number of classes. Provide 1 for binary problems.
- min_precision (float) – float value specifying the minimum precision threshold.
- num_thresholds (int) – number of bins used for computation. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory. Default: 100
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedRecallAtFixedPrecision
>>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> recall_at_precision = BinnedRecallAtFixedPrecision(num_classes=1, num_thresholds=10, min_precision=0.5)
>>> recall_at_precision(pred, target)
(tensor(1.0000), tensor(0.1111))
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> recall_at_precision = BinnedRecallAtFixedPrecision(num_classes=5, num_thresholds=10, min_precision=0.5)
>>> recall_at_precision(pred, target)
(tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]),
 tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
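The examples return a (recall, threshold) pair: the best recall found among the binned thresholds whose precision reaches min_precision, and the threshold where it occurs. The pure-Python sketch below searches the same i / (num_thresholds - 1) grid. Ties are broken by comparing (recall, precision, threshold) tuples. That tie-break is an assumption, chosen because it reproduces the 0.1111 threshold in the binary example and the 0.6667 threshold for classes 0 and 1 in the multiclass example. For classes whose best recall is 0, the multiclass output shows a sentinel threshold of 1.0000e+06. This sketch returns None for the threshold in that case. `recall_at_min_precision` is hypothetical.

```python
def recall_at_min_precision(scores, labels, min_precision, num_thresholds=100):
    """Best recall (and its threshold) among grid thresholds meeting min_precision."""
    positives = sum(labels)
    best = None  # (recall, precision, threshold)
    for i in range(num_thresholds):
        t = i / (num_thresholds - 1)
        predicted = [y for s, y in zip(scores, labels) if s >= t]
        tp = sum(predicted)
        precision = tp / len(predicted) if predicted else 1.0  # assumed convention
        recall = tp / positives
        if precision >= min_precision:
            # Tuple comparison: higher recall first, then higher precision, then higher threshold.
            candidate = (recall, precision, t)
            if best is None or candidate > best:
                best = candidate
    if best is None or best[0] == 0:
        return 0.0, None  # no threshold gives non-zero recall at this precision
    return best[0], best[2]


r, t = recall_at_min_precision([0, 0.2, 0.5, 0.8], [0, 1, 1, 0], min_precision=0.5, num_thresholds=10)
print(r, round(t, 4))  # 1.0 0.1111

# Class 0 of the multiclass example, one-vs-rest.
r0, t0 = recall_at_min_precision([0.75, 0.05, 0.05, 0.05], [1, 0, 0, 0], 0.5, num_thresholds=10)
print(r0, round(t0, 4))  # 1.0 0.6667
```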
CohenKappa
class torchmetrics.CohenKappa(num_classes, weights=None, threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Calculates Cohen's kappa score that measures inter-annotator agreement. It is defined as
$$\kappa = \frac{p_o - p_e}{1 - p_e}$$
where $p_o$ is the empirical probability of agreement and $p_e$ is the expected agreement when both annotators assign labels randomly. Note that $p_e$ is estimated using a per-annotator empirical prior over the class labels.
Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.
Forward accepts
- preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes
- target (long tensor): (N, ...)
If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores, we perform an argmax on dim=1.
Parameters
- num_classes (int) – Number of classes in the dataset.
- weights (Optional[str]) – Weighting type to calculate the score. Choose from:
  None or 'none': no weighting
  'linear': linear weighting
  'quadratic': quadratic weighting
- threshold (float) – Threshold value for binary or multi-label probabilities. Default: 0.5
- compute_on_step (bool) – forward() only calls update() and returns None if this is set to False. Default: True
- dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. Default: False
- process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import CohenKappa
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> cohenkappa = CohenKappa(num_classes=2)
>>> cohenkappa(preds, target)
tensor(0.5000)
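The definition above can be checked with the equivalent weighted-disagreement form, kappa = 1 - sum(w * O) / sum(w * E). Here O is the observed confusion matrix, and E is the outer product of its row and column totals divided by N. For unweighted kappa, w is 0 on the diagonal and 1 elsewhere; for linear and quadratic weighting it is |i - j| and (i - j)^2 respectively. This is the textbook formulation, assumed here rather than taken from the torchmetrics source. The pure-Python sketch reproduces the example: p_o = 3/4 and p_e = 1/2, giving kappa = 0.5. `cohen_kappa` is a hypothetical helper that takes integer labels.

```python
def cohen_kappa(preds, target, num_classes, weights=None):
    """Cohen's kappa from integer labels, optionally with linear/quadratic weights."""
    n = len(target)
    # Observed confusion matrix: rows = target, columns = prediction.
    observed = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        observed[t][p] += 1
    row_totals = [sum(row) for row in observed]
    col_totals = [sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)]

    def weight(i, j):
        if weights is None or weights == "none":
            return 0.0 if i == j else 1.0
        if weights == "linear":
            return abs(i - j)
        if weights == "quadratic":
            return (i - j) ** 2
        raise ValueError(f"unknown weights: {weights}")

    cells = [(i, j) for i in range(num_classes) for j in range(num_classes)]
    # Observed disagreement vs. disagreement expected from the per-annotator label frequencies.
    observed_disagreement = sum(weight(i, j) * observed[i][j] for i, j in cells)
    expected_disagreement = sum(weight(i, j) * row_totals[i] * col_totals[j] / n for i, j in cells)
    return 1.0 - observed_disagreement / expected_disagreement


print(cohen_kappa([0, 1, 0, 0], [1, 1, 0, 0], num_classes=2))  # 0.5
```

With two classes, all three weighting schemes coincide, because every disagreement is one class apart.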
ConfusionMatrix¶
-
class
torchmetrics.
ConfusionMatrix
(num_classes, normalize=None, threshold=0.5, multilabel=False, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source] Computes the confusion matrix. Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.
Forward accepts
preds
(float or long tensor):(N, ...)
or(N, C, ...)
where C is the number of classestarget
(long tensor):(N, ...)
If preds and target are the same shape and preds is a float tensor, we use the
self.threshold
argument to convert into integer labels. This is the case for binary and multi-label probabilities.If preds has an extra dimension as in the case of multi-class scores we perform an argmax on
dim=1
.If working with multilabel data, setting the is_multilabel argument to True will make sure that a confusion matrix gets calculated per label.
- Parameters
normalize (Optional[str]) – Normalization mode for confusion matrix. Choose from:
    None or 'none': no normalization (default)
    'true': normalization over the targets (most commonly used)
    'pred': normalization over the predictions
    'all': normalization over the whole matrix
threshold (float) – Threshold value for binary or multi-label probabilities. default: 0.5
multilabel (bool) – Determines if data is multilabel or not.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
- Example (binary data):
>>> import torch
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
        [1., 1.]])
- Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
- Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target)
tensor([[[1., 0.], [0., 1.]], [[1., 0.], [1., 0.]], [[0., 1.], [0., 1.]]])
compute()
Computes confusion matrix.
- Return type
Tensor
- Returns
If multilabel=False this will be a [n_classes, n_classes] tensor, and if multilabel=True this will be a [n_classes, 2, 2] tensor.
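The normalize modes differ only in what each count is divided by. The following pure-Python sketch illustrates this on plain label lists; it is an assumption-laden illustration, not torchmetrics' code. The function name confusion_matrix is hypothetical, and mapping empty rows or columns to 0.0 is a choice made for this sketch.

```python
from typing import List, Optional


def confusion_matrix(preds: List[int], target: List[int], num_classes: int,
                     normalize: Optional[str] = None) -> List[List[float]]:
    """Hypothetical helper (not torchmetrics): rows are true classes, columns predictions."""
    mat = [[0.0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        mat[t][p] += 1.0

    if normalize in (None, "none"):
        return mat
    if normalize == "true":
        # Divide each row by the number of samples of that true class.
        row_sums = [sum(row) for row in mat]
        return [[v / row_sums[i] if row_sums[i] else 0.0 for v in row]
                for i, row in enumerate(mat)]
    if normalize == "pred":
        # Divide each column by the number of samples predicted as that class.
        col_sums = [sum(row[j] for row in mat) for j in range(num_classes)]
        return [[v / col_sums[j] if col_sums[j] else 0.0 for j, v in enumerate(row)]
                for row in mat]
    if normalize == "all":
        total = sum(sum(row) for row in mat)
        return [[v / total for v in row] for row in mat]
    raise ValueError(f"unknown normalize mode: {normalize!r}")


print(confusion_matrix([0, 1, 0, 0], [1, 1, 0, 0], num_classes=2))
# [[2.0, 0.0], [1.0, 1.0]]
print(confusion_matrix([0, 1, 0, 0], [1, 1, 0, 0], num_classes=2, normalize="true"))
# [[1.0, 0.0], [0.5, 0.5]]
```

With normalize='true' the diagonal holds per-class recall, which is why that mode is the one most commonly used.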
F1
class torchmetrics.F1(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, multilabel=None)
Computes the F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.
Works with binary, multiclass, and multilabel data. Accepts logits from a model output or integer class values in prediction. Works with multi-dimensional preds and target.
Forward accepts
- preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes
- target (long tensor): (N, ...)
If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument. This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.
- Parameters
num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
average (Optional[str]) – Defines the reduction that is applied. Should be one of the following:
    'micro' [default]: Calculate the metric globally, across all samples and classes.
    'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
    'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
    'none' or None: Calculate the metric for each class separately, and return the metric for every class.
    'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
    Note: What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
mdmc_average (Optional[str]) – Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
    None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.
    'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.
    'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
top_k (Optional[int]) – Number of highest probability entries for each sample to convert to 1s; relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (None) for inputs with label predictions.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
multilabel (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
Example
>>> import torch
>>> from torchmetrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
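The difference between 'micro' and 'macro' averaging comes down to whether the per-class counts are pooled before or after computing the score. This is a minimal pure-Python sketch on integer label lists, not the torchmetrics implementation; the helpers class_counts, f1_from_counts and f1_score are hypothetical names, and returning 0.0 for an undefined 0/0 score is an assumption of the sketch.

```python
from typing import List, Optional, Tuple, Union


def class_counts(preds: List[int], target: List[int],
                 num_classes: int) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: per-class true positives, false positives, false negatives."""
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for p, t in zip(preds, target):
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    return tp, fp, fn


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    # 2PR / (P + R) simplifies to 2TP / (2TP + FP + FN); 0/0 is mapped to 0.0 here.
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def f1_score(preds: List[int], target: List[int], num_classes: int,
             average: Optional[str] = "micro") -> Union[float, List[float]]:
    tp, fp, fn = class_counts(preds, target, num_classes)
    if average == "micro":
        # Pool the counts over all classes before computing a single score.
        return f1_from_counts(sum(tp), sum(fp), sum(fn))
    per_class = [f1_from_counts(tp[c], fp[c], fn[c]) for c in range(num_classes)]
    if average == "macro":
        return sum(per_class) / num_classes
    if average in (None, "none"):
        return per_class
    raise ValueError(f"unsupported average: {average!r}")


preds, target = [0, 2, 1, 0, 0, 1], [0, 1, 2, 0, 1, 2]
print(f1_score(preds, target, 3))          # micro, about 0.3333
print(f1_score(preds, target, 3, "none"))  # [0.8, 0.0, 0.0]
```

On the example data the micro score matches the 0.3333 shown above, while the macro score would be about 0.2667 because only class 0 has any correct predictions.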
FBeta
class torchmetrics.FBeta(num_classes=None, beta=1.0, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, multilabel=None)
Computes F-score, specifically:
$$F_\beta = (1 + \beta^2) \cdot \frac{\text{precision} \cdot \text{recall}}{(\beta^2 \cdot \text{precision}) + \text{recall}}$$
where $\beta$ is some positive real factor. Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.
Forward accepts
- preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes
- target (long tensor): (N, ...)
If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.
- Parameters
num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
average (Optional[str]) – Defines the reduction that is applied. Should be one of the following:
    'micro' [default]: Calculate the metric globally, across all samples and classes.
    'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
    'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
    'none' or None: Calculate the metric for each class separately, and return the metric for every class.
    'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
    Note: What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
mdmc_average (Optional[str]) – Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
    None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.
    'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.
    'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
top_k (Optional[int]) – Number of highest probability entries for each sample to convert to 1s; relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (None) for inputs with label predictions.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
multilabel (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
- Raises
ValueError – If average is none of "micro", "macro", "weighted", "none", None.
Example
>>> import torch
>>> from torchmetrics import FBeta
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBeta(num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)
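The F-beta formula can be checked directly from a precision and a recall value. The sketch below is plain Python and not the torchmetrics code path; the function name fbeta is hypothetical, and returning 0.0 when both inputs are 0 is a convention chosen here. It also explains why the example above returns 0.3333 even with beta=0.5: for micro-averaged single-label multiclass input, precision and recall both equal accuracy (2/6), and F-beta of two equal values is that value for any beta.

```python
def fbeta(precision: float, recall: float, beta: float) -> float:
    """Hypothetical helper: F-beta score from precision and recall."""
    if precision == 0.0 and recall == 0.0:
        return 0.0  # convention of this sketch for the undefined 0/0 case
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Micro-averaged single-label multiclass: precision == recall == accuracy == 2/6.
print(fbeta(2 / 6, 2 / 6, beta=0.5))  # about 0.3333

# When precision and recall differ, beta decides which one dominates.
for beta in (0.5, 1.0, 2.0):
    print(beta, round(fbeta(0.5, 0.25, beta), 4))
# 0.5 0.4167
# 1.0 0.3333
# 2.0 0.2778
```

Values of beta below 1 pull the score toward precision, values above 1 toward recall, and beta=1 recovers F1.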
HammingDistance
class torchmetrics.HammingDistance(threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes the average Hamming distance (also known as Hamming loss) between targets and predictions:
$$\text{Hamming distance} = \frac{1}{N \cdot L} \sum_{i=1}^{N} \sum_{l=1}^{L} 1(y_{il} \neq \hat{y}_{il})$$
where $y$ is a tensor of target values, $\hat{y}$ is a tensor of predictions, and $\bullet_{il}$ refers to the $l$-th label of the $i$-th sample of that tensor.
This is the same as 1 - accuracy for binary data, while for all other types of inputs it treats each possible label separately, meaning that, for example, multi-class data is treated as if it were multi-label.
Accepts all input types listed in Input types.
- Parameters
threshold (float) – Threshold probability value for transforming probability predictions to binary (0 or 1) predictions, in the case of binary or multi-label inputs.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
- Raises
ValueError – If threshold is not between 0 and 1.
Example
>>> import torch
>>> from torchmetrics import HammingDistance
>>> target = torch.tensor([[0, 1], [1, 1]])
>>> preds = torch.tensor([[0, 1], [0, 1]])
>>> hamming_distance = HammingDistance()
>>> hamming_distance(preds, target)
tensor(0.2500)
compute()
Computes hamming distance based on inputs passed in to update previously.
- Return type
Tensor
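The Hamming distance formula is just the fraction of (sample, label) entries that disagree after binarization. Here is a minimal pure-Python sketch for multi-label input given as nested lists; it is not torchmetrics' implementation. The function name hamming_distance is hypothetical, and binarizing with p >= threshold is an assumption of this sketch.

```python
from typing import Sequence


def hamming_distance(preds: Sequence[Sequence[float]], target: Sequence[Sequence[int]],
                     threshold: float = 0.5) -> float:
    """Hypothetical helper: share of (sample, label) entries where prediction != target."""
    mismatches, total = 0, 0
    for pred_row, target_row in zip(preds, target):
        for p, t in zip(pred_row, target_row):
            label = 1 if p >= threshold else 0  # assumption: '>=' binarization
            mismatches += int(label != t)
            total += 1
    return mismatches / total


# Same data as the example above: one of the four entries differs.
print(hamming_distance([[0, 1], [0, 1]], [[0, 1], [1, 1]]))          # 0.25
# Probabilities are thresholded first, giving the same binary predictions.
print(hamming_distance([[0.2, 0.9], [0.4, 0.7]], [[0, 1], [1, 1]]))  # 0.25
```

Because every label is scored independently, one wrong label out of L in a sample costs 1/L for that sample rather than the whole sample, which is what distinguishes this from subset accuracy.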
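Hinge
class torchmetrics.Hinge(squared=False, multiclass_mode=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes the mean Hinge loss, typically used for Support Vector Machines (SVMs). In the binary case it is defined as:
$$\text{Hinge loss} = \max(0, 1 - y \times \hat{y})$$
where $y \in \{-1, 1\}$ is the target, and $\hat{y} \in \mathbb{R}$ is the prediction. Binary targets given as 0/1, as in the binary example below, correspond to $y = -1$ and $y = 1$.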
In the multi-class case, when multiclass_mode=None (default), multiclass_mode=MulticlassMode.CRAMMER_SINGER or multiclass_mode="crammer-singer", this metric will compute the multi-class hinge loss defined by Crammer and Singer as:
$$\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \neq y} \hat{y}_i\right)$$
where $y \in \{0, \ldots, C - 1\}$ is the target class (where $C$ is the number of classes), and $\hat{y} \in \mathbb{R}^C$ is the predicted output per class.
In the multi-class case when multiclass_mode=MulticlassMode.ONE_VS_ALL or multiclass_mode='one-vs-all', this metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits that class against all remaining classes.
This metric can optionally output the mean of the squared hinge loss by setting squared=True.
Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N).
- Parameters
squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default).
multiclass_mode (Union[str, MulticlassMode, None]) – Which approach to use for multi-class inputs (has no effect in the binary case). None (default), MulticlassMode.CRAMMER_SINGER or "crammer-singer" uses the Crammer-Singer multi-class hinge loss. MulticlassMode.ONE_VS_ALL or "one-vs-all" computes the hinge loss in a one-vs-all fashion.
- Raises
ValueError – If multiclass_mode is not: None, MulticlassMode.CRAMMER_SINGER, "crammer-singer", MulticlassMode.ONE_VS_ALL or "one-vs-all".
- Example (binary case):
>>> import torch
>>> from torchmetrics import Hinge
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(0.3000)
- Example (default / multiclass case):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(2.9000)
- Example (multiclass example, one vs all mode):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge(multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])
compute()
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
- Return type
Tensor
update(preds, target)
Override this method to update the state variables of your metric class.
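The three hinge variants described above can be reproduced in a few lines of plain Python. This sketch is not the torchmetrics implementation; the function names binary_hinge, crammer_singer_hinge and one_vs_all_hinge are hypothetical, and it works on lists rather than tensors. As in the documented examples, binary targets 0/1 are mapped to -1/+1 and the squared variant squares each per-sample loss before averaging.

```python
from typing import List


def binary_hinge(preds: List[float], target: List[int], squared: bool = False) -> float:
    """Hypothetical helper: mean of max(0, 1 - y * y_hat) with y in {-1, +1}."""
    losses = []
    for p, t in zip(preds, target):
        y = 1.0 if t == 1 else -1.0  # map {0, 1} targets to {-1, +1}
        loss = max(0.0, 1.0 - y * p)
        losses.append(loss ** 2 if squared else loss)
    return sum(losses) / len(losses)


def crammer_singer_hinge(preds: List[List[float]], target: List[int],
                         squared: bool = False) -> float:
    losses = []
    for scores, y in zip(preds, target):
        # Highest score among the wrong classes competes with the true-class score.
        runner_up = max(s for i, s in enumerate(scores) if i != y)
        loss = max(0.0, 1.0 - scores[y] + runner_up)
        losses.append(loss ** 2 if squared else loss)
    return sum(losses) / len(losses)


def one_vs_all_hinge(preds: List[List[float]], target: List[int],
                     squared: bool = False) -> List[float]:
    # One binary hinge loss per class: class c (positive) versus all other classes.
    num_classes = len(preds[0])
    return [binary_hinge([row[c] for row in preds], [int(y == c) for y in target], squared)
            for c in range(num_classes)]


scores = [[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]
print(binary_hinge([-2.2, 2.4, 0.1], [0, 1, 1]))  # about 0.3
print(crammer_singer_hinge(scores, [0, 1, 2]))    # about 2.9
print(one_vs_all_hinge(scores, [0, 1, 2]))        # about [2.2333, 1.5, 1.2333]
```

In the multiclass example every row happens to incur the same Crammer-Singer loss of 2.9, which is why the mean is exactly 2.9.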
IoU
class torchmetrics.IoU(num_classes, ignore_index=None, absent_score=0.0, threshold=0.5, reduction='elementwise_mean', compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes Intersection over Union, or Jaccard index:
$$J(A, B) = \frac{|A \cap B|}{|A \cup B|}$$
where $A$ and $B$ are both tensors of the same size, containing integer class values. They may be subject to conversion from input data (see description below). Note that it is different from box IoU.
Works with binary, multiclass and multi-label data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.
Forward accepts
- preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes
- target (long tensor): (N, ...)
If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.
- Parameters
ignore_index (Optional[int]) – Optional int specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. Has no effect if given an int that is not in the range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
absent_score (float) – Score to use for an individual class, if no instances of the class index were present in pred AND no instances of the class index were present in target. For example, if we have 3 classes, [0, 0] for pred, and [0, 2] for target, then class 1 would be assigned the absent_score.
threshold (float) – Threshold value for binary or multi-label probabilities.
reduction (str) – A method to reduce metric score over labels:
    'elementwise_mean': takes the mean (default)
    'sum': takes the sum
    'none': no reduction will be applied
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import IoU
>>> target = torch.randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> iou = IoU(num_classes=2)
>>> iou(pred, target)
tensor(0.9660)
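The role of absent_score and reduction is easiest to see on the tiny scenario from the absent_score description. This pure-Python sketch is not torchmetrics' implementation: the function name iou is hypothetical, inputs are flat label lists, and ignore_index and threshold handling are deliberately left out.

```python
from typing import List, Union


def iou(preds: List[int], target: List[int], num_classes: int,
        absent_score: float = 0.0,
        reduction: str = "elementwise_mean") -> Union[float, List[float]]:
    """Hypothetical helper: per-class Jaccard index |A & B| / |A | B| on label lists."""
    scores = []
    for c in range(num_classes):
        in_pred = [p == c for p in preds]
        in_target = [t == c for t in target]
        intersection = sum(a and b for a, b in zip(in_pred, in_target))
        union = sum(a or b for a, b in zip(in_pred, in_target))
        # A class absent from both preds and target gets absent_score instead of 0/0.
        scores.append(intersection / union if union else absent_score)
    if reduction == "elementwise_mean":
        return sum(scores) / num_classes
    if reduction == "sum":
        return sum(scores)
    if reduction == "none":
        return scores
    raise ValueError(f"unsupported reduction: {reduction!r}")


# 3 classes, pred [0, 0], target [0, 2]: class 1 never appears.
print(iou([0, 0], [0, 2], num_classes=3, reduction="none"))                    # [0.5, 0.0, 0.0]
print(iou([0, 0], [0, 2], num_classes=3, absent_score=1.0, reduction="none"))  # [0.5, 1.0, 0.0]
```

Because the default absent_score of 0.0 is averaged in like any other class score, a class that simply never occurs lowers the elementwise mean; setting it to 1.0 treats such classes as perfectly handled instead.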
MatthewsCorrcoef
class torchmetrics.MatthewsCorrcoef(num_classes, threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Calculates the Matthews correlation coefficient, which measures the general correlation or quality of a classification. In the binary case it is defined as:
$$MCC = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$
where TP, TN, FP and FN are respectively the true positives, true negatives, false positives and false negatives. Also works in the case of multi-label or multi-class input.
Note
This metric returns a single scalar tensor (see the example below), unlike ConfusionMatrix, whose output is multi-dimensional.
Forward accepts
- preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes
- target (long tensor): (N, ...)
If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.
- Parameters
threshold (float) – Threshold value for binary or multi-label probabilities. default: 0.5
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Example
>>> import torch
>>> from torchmetrics import MatthewsCorrcoef
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> matthews_corrcoef = MatthewsCorrcoef(num_classes=2)
>>> matthews_corrcoef(preds, target)
tensor(0.5774)
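The binary MCC formula can be evaluated by hand from the four counts. This is a plain-Python sketch of the formula only, not the torchmetrics implementation; the function name matthews_corrcoef_binary is hypothetical, and returning 0.0 when the denominator is zero is a convention chosen for this sketch. For the example data the counts are TP=1, TN=2, FP=0, FN=1, so MCC = 2 / sqrt(12) = 1 / sqrt(3), about 0.5774.

```python
import math
from typing import List


def matthews_corrcoef_binary(preds: List[int], target: List[int]) -> float:
    """Hypothetical helper: binary MCC from 0/1 label lists."""
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    tn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 0)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Convention of this sketch: return 0.0 when any marginal count is zero.
    return (tp * tn - fp * fn) / denom if denom else 0.0


print(matthews_corrcoef_binary([0, 1, 0, 0], [1, 1, 0, 0]))  # about 0.5774
```

Unlike accuracy, MCC ranges from -1 (total disagreement) through 0 (no better than chance) to +1 (perfect prediction).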
Precision
class torchmetrics.Precision(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, multilabel=None, is_multiclass=None)
Computes Precision:
$$\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}$$
where $\text{TP}$ and $\text{FP}$ represent the number of true positives and false positives respectively. With the use of the top_k parameter, this metric can generalize to Precision@K.
The reduction method (how the precision scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.
- Parameters
num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
average (Optional[str]) – Defines the reduction that is applied. Should be one of the following:
    'micro' [default]: Calculate the metric globally, across all samples and classes.
    'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
    'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
    'none' or None: Calculate the metric for each class separately, and return the metric for every class.
    'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
    Note: What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
mdmc_average (Optional[str]) – Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
    None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.
    'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.
    'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
top_k (Optional[int]) – Number of highest probability entries for each sample to convert to 1s; relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (None) for inputs with label predictions.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
multilabel (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
is_multiclass (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
- Raises
ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.
Example
>>> import torch
>>> from torchmetrics import Precision
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> precision = Precision(average='macro', num_classes=3)
>>> precision(preds, target)
tensor(0.1667)
>>> precision = Precision(average='micro')
>>> precision(preds, target)
tensor(0.2500)
compute()
Computes the precision score based on inputs passed in to update previously.
- Return type
Tensor
- Returns
The shape of the returned tensor depends on the average parameter:
If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned.
If average in ['none', None], the shape will be (C,), where C stands for the number of classes.
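To see where the two numbers in the Precision example come from, here is a minimal pure-Python sketch for single-label multiclass input. It is not torchmetrics' implementation: the function name precision_score is hypothetical, inputs are integer label lists, and scoring a never-predicted class as 0.0 is an assumption of the sketch.

```python
from typing import List, Optional, Union


def precision_score(preds: List[int], target: List[int], num_classes: int,
                    average: Optional[str] = "micro") -> Union[float, List[float]]:
    """Hypothetical helper: TP / (TP + FP) with micro, macro or per-class reduction."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    for p, t in zip(preds, target):
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
    if average == "micro":
        # Pool counts over classes, then divide once.
        return sum(tp) / (sum(tp) + sum(fp))
    # Convention of this sketch: a class that is never predicted scores 0.0.
    per_class = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in range(num_classes)]
    if average == "macro":
        return sum(per_class) / num_classes
    if average in (None, "none"):
        return per_class
    raise ValueError(f"unsupported average: {average!r}")


preds, target = [2, 0, 2, 1], [1, 1, 2, 0]
print(precision_score(preds, target, 3, "none"))   # [0.0, 0.0, 0.5]
print(precision_score(preds, target, 3, "macro"))  # about 0.1667
print(precision_score(preds, target, 3, "micro"))  # 0.25
```

Only class 2 has a correct prediction (one of its two), so the macro average is 0.5 / 3, while micro averaging pools the single true positive over all four predictions.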
PrecisionRecallCurve
class torchmetrics.PrecisionRecallCurve(num_classes=None, pos_label=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) or (N, C, ...) with integer labels
- Parameters
num_classes (Optional[int]) – Integer with number of classes. Not necessary to provide for binary problems.
pos_label (Optional[int]) – Integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as we iteratively change it in the range [0, num_classes-1].
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
- Example (binary case):
>>> import torch
>>> from torchmetrics import PrecisionRecallCurve
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = PrecisionRecallCurve(pos_label=1)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])
- Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(num_classes=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
compute()
Compute the precision-recall curve.
- Return type
Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]
- Returns
3-element tuple containing
- precision: tensor where element i is the precision of predictions with score >= thresholds[i] and the last element is 1. If multiclass, this is a list of such tensors, one for each class.
- recall: tensor where element i is the recall of predictions with score >= thresholds[i] and the last element is 0. If multiclass, this is a list of such tensors, one for each class.
- thresholds: thresholds used for computing precision/recall scores.
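The binary example above follows a specific convention: thresholds are the distinct scores, the curve stops at the first threshold that reaches full recall, and a final (precision=1, recall=0) point is appended. This pure-Python sketch reproduces that convention (the same one used by scikit-learn's precision_recall_curve); it is not torchmetrics' code. The function name binary_pr_curve is hypothetical, and it assumes at least one positive sample.

```python
from typing import List, Sequence, Tuple


def binary_pr_curve(scores: Sequence[float], labels: Sequence[int],
                    pos_label: int = 1) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: precision/recall at every distinct score threshold."""
    positives = sum(1 for y in labels if y == pos_label)
    if positives == 0:
        raise ValueError("this sketch assumes at least one positive sample")
    precision, recall, thresholds = [], [], []
    # Walk the distinct scores from highest to lowest, i.e. from strictest to loosest threshold.
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == pos_label)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y != pos_label)
        precision.append(tp / (tp + fp))
        recall.append(tp / positives)
        thresholds.append(t)
        if tp == positives:
            break  # full recall reached; lower thresholds could only add false positives
    # Report in increasing-threshold order and append the (precision=1, recall=0) end point.
    precision.reverse()
    recall.reverse()
    thresholds.reverse()
    return precision + [1.0], recall + [0.0], thresholds


p, r, th = binary_pr_curve([0, 1, 2, 3], [0, 1, 1, 0])
print([round(v, 4) for v in p])  # [0.6667, 0.5, 0.0, 1.0]
print(r)                         # [1.0, 0.5, 0.0, 0.0]
print(th)                        # [1, 2, 3]
```

This is why precision and recall each have one more element than thresholds: the appended end point has no threshold of its own.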
Recall
class torchmetrics.Recall(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, multilabel=None, is_multiclass=None)
Computes Recall:
$$\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}$$
where $\text{TP}$ and $\text{FN}$ represent the number of true positives and false negatives respectively. With the use of the top_k parameter, this metric can generalize to Recall@K.
The reduction method (how the recall scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.
- Parameters
num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
average (Optional[str]) – Defines the reduction that is applied. Should be one of the following:
    'micro' [default]: Calculate the metric globally, across all samples and classes.
    'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
    'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
    'none' or None: Calculate the metric for each class separately, and return the metric for every class.
    'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
    Note: What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
mdmc_average (Optional[str]) – Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
    None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.
    'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.
    'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
top_k (Optional[int]) – Number of highest probability entries for each sample to convert to 1s; relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (None) for inputs with label predictions.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
multilabel (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
is_multiclass (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
- Raises
ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.
Example
>>> import torch
>>> from torchmetrics import Recall
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> recall = Recall(average='macro', num_classes=3)
>>> recall(preds, target)
tensor(0.3333)
>>> recall = Recall(average='micro')
>>> recall(preds, target)
tensor(0.2500)
compute()
Computes the recall score based on inputs passed in to update previously.
- Return type
Tensor
- Returns
The shape of the returned tensor depends on the average parameter:
If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned.
If average in ['none', None], the shape will be (C,), where C stands for the number of classes.
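The averaging modes can be made concrete with a plain-Python sketch that recomputes per-class, macro and micro recall for the example above from true positive and false negative counts. It assumes single-label multiclass inputs given as label predictions; `recall_scores` is a hypothetical helper for illustration, not part of torchmetrics.

```python
from typing import List, Tuple


def recall_scores(preds: List[int], target: List[int], num_classes: int) -> Tuple[List[float], float, float]:
    """Hypothetical helper: per-class, macro and micro recall from label predictions."""
    tp = [0] * num_classes  # true positives per class
    fn = [0] * num_classes  # false negatives per class
    for p, t in zip(preds, target):
        if p == t:
            tp[t] += 1
        else:
            fn[t] += 1  # the true class t was missed
    # Per-class recall = tp / (tp + fn); classes with no support get 0.0 in this sketch.
    per_class = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in range(num_classes)]
    macro = sum(per_class) / num_classes   # every class weighted equally
    micro = sum(tp) / (sum(tp) + sum(fn))  # counts pooled over all classes
    return per_class, macro, micro


per_class, macro, micro = recall_scores([2, 0, 2, 1], [1, 1, 2, 0], num_classes=3)
print(per_class)        # [0.0, 0.0, 1.0]  (one value per class, like average='none')
print(round(macro, 4))  # 0.3333
print(micro)            # 0.25
```

Only class 2 is ever recalled, so the macro average is 1/3, while pooling the counts gives 1 hit out of 4 target samples.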
ROC
class torchmetrics.ROC(num_classes=None, pos_label=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes the Receiver Operating Characteristic (ROC). Works for binary, multiclass and multilabel problems. In the multiclass case, the values are calculated with a one-vs-the-rest approach.
Forward accepts
preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass/multilabel) tensor with probabilities, where C is the number of classes/labels.
target (long tensor): (N, ...) or (N, C, ...) with integer labels.
- Parameters
num_classes (Optional[int]) – Integer with the number of classes. Not necessary to provide for binary problems.
pos_label (Optional[int]) – Integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1].
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
- Example (binary case):
>>> import torch
>>> from torchmetrics import ROC
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> roc = ROC(pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
- Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> roc = ROC(num_classes=4)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.7500, 0.7500, 0.0500]), tensor([1.7500, 0.7500, 0.0500]), tensor([1.7500, 0.7500, 0.0500]), tensor([1.7500, 0.7500, 0.0500])]
- Example (multilabel case):
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(num_classes=3, pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]), tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]
compute()
Compute the receiver operating characteristic.
- Return type
Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]
- Returns
3-element tuple containing
- fpr: tensor with false positive rates. If multiclass, this is a list of such tensors, one for each class.
- tpr: tensor with true positive rates. If multiclass, this is a list of such tensors, one for each class.
- thresholds: thresholds used for computing false and true positive rates.
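For the binary case, the curve can be traced by sweeping a threshold over the distinct scores and counting positives and negatives at or above it. The plain-Python sketch below reproduces the binary example above. `binary_roc` is hypothetical; it prepends `max(score) + 1` as the first threshold only because that is what the example output shows, and it recounts for every threshold instead of using cumulative sums, which is simpler but quadratic in the worst case.

```python
from typing import List, Tuple


def binary_roc(preds: List[float], target: List[int]) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: ROC points for binary targets (1 = positive, 0 = negative)."""
    pos = sum(target)
    neg = len(target) - pos
    # Distinct scores in descending order act as decision thresholds.
    thresholds = sorted(set(preds), reverse=True)
    # Extra threshold above every score so the curve starts at (0, 0).
    thresholds = [thresholds[0] + 1] + thresholds
    fpr, tpr = [], []
    for thr in thresholds:
        # A sample is predicted positive when its score is >= the threshold.
        tp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 1)
        fp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 0)
        tpr.append(tp / pos)
        fpr.append(fp / neg)
    return fpr, tpr, thresholds


fpr, tpr, thresholds = binary_roc([0, 1, 2, 3], [0, 1, 1, 1])
print(fpr)         # [0.0, 0.0, 0.0, 0.0, 1.0]
print(thresholds)  # [4, 3, 2, 1, 0]
```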
StatScores
class torchmetrics.StatScores(threshold=0.5, top_k=None, reduce='micro', num_classes=None, ignore_index=None, mdmc_reduce=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, is_multiclass=None)
Computes the number of true positives, false positives, true negatives and false negatives. Related to Type I and Type II errors and the confusion matrix.
The reduction method (how the statistics are aggregated) is controlled by the reduce parameter, and additionally by the mdmc_reduce parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.
- Parameters
threshold (float) – Threshold probability value for transforming probability predictions to binary (0 or 1) predictions, in the case of binary or multi-label inputs.
Number of highest probability entries for each sample to convert to 1s, relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1. Should be left unset (None) for inputs with label predictions.
Defines the reduction that is applied. Should be one of the following:
'micro' [default]: Counts the statistics by summing over all [sample, class] combinations (globally). Each statistic is represented by a single integer.
'macro': Counts the statistics for each class separately (over all samples). Each statistic is represented by a (C,) tensor. Requires num_classes to be set.
'samples': Counts the statistics for each sample separately (over all classes). Each statistic is represented by a (N,) 1d tensor.
Note
What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_reduce.
num_classes (Optional[int]) – Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.
ignore_index (Optional[int]) – Specify a class (label) to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored and reduce='macro', the class statistics for the ignored class will all be returned as -1.
mdmc_reduce (Optional[str]) – Defines how the multi-dimensional multi-class inputs are handled. Should be one of the following:
None [default]: Should be left unchanged if your data is not multi-dimensional multi-class (see Input types for the definition of input types).
'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then the outputs are concatenated together. In each sample the extra axes ... are flattened to become the sub-sample axis, and statistics for each sample are computed by treating the sub-sample axis as the N axis for that sample.
'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the reduce parameter applies as usual.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter's documentation section for a more detailed explanation and examples.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
is_multiclass (Optional[bool]) – Deprecated since version 0.3: Argument will not have any effect and will be removed in v0.4, please use multiclass instead.
- Raises
ValueError – If threshold is not a float between 0 and 1.
ValueError – If reduce is none of "micro", "macro" or "samples".
ValueError – If mdmc_reduce is none of None, "samplewise", "global".
ValueError – If reduce is set to "macro" and num_classes is not provided.
ValueError – If num_classes is set and ignore_index is not in the range 0 <= ignore_index < num_classes.
Example
>>> import torch
>>> from torchmetrics.classification import StatScores
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(reduce='macro', num_classes=3)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
>>> stat_scores = StatScores(reduce='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])
compute()
Computes the stat scores based on inputs passed in to update previously.
- Return type
Tensor
- Returns
The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the reduce and mdmc_reduce (in case of multi-dimensional multi-class data) parameters:
If the data is not multi-dimensional multi-class, then
- If reduce='micro', the shape will be (5,)
- If reduce='macro', the shape will be (C, 5), where C stands for the number of classes
- If reduce='samples', the shape will be (N, 5), where N stands for the number of samples
If the data is multi-dimensional multi-class and mdmc_reduce='global', then
- If reduce='micro', the shape will be (5,)
- If reduce='macro', the shape will be (C, 5)
- If reduce='samples', the shape will be (N*X, 5), where X stands for the product of sizes of all “extra” dimensions of the data (i.e. all dimensions except for C and N)
If the data is multi-dimensional multi-class and mdmc_reduce='samplewise', then
- If reduce='micro', the shape will be (N, 5)
- If reduce='macro', the shape will be (N, C, 5)
- If reduce='samples', the shape will be (N, X, 5)
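Each row of the reduce='macro' output is a one-vs-rest count for one class, and reduce='micro' sums those rows. The plain-Python sketch below reproduces both outputs of the StatScores example for single-label multiclass label predictions; `stat_scores_macro` is a hypothetical helper, not the library implementation.

```python
from typing import List


def stat_scores_macro(preds: List[int], target: List[int], num_classes: int) -> List[List[int]]:
    """Hypothetical helper: one [tp, fp, tn, fn, sup] row per class (one-vs-rest)."""
    rows = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, target) if p != c and t == c)
        tn = len(target) - tp - fp - fn  # every remaining sample is a true negative for c
        rows.append([tp, fp, tn, fn, tp + fn])  # sup = tp + fn
    return rows


macro = stat_scores_macro([1, 0, 2, 1], [1, 1, 2, 0], num_classes=3)
micro = [sum(column) for column in zip(*macro)]  # sum each statistic over classes
print(macro)  # [[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]
print(micro)  # [2, 2, 6, 2, 4]
```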
Regression Metrics
ExplainedVariance
class torchmetrics.ExplainedVariance(multioutput='uniform_average', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes explained variance:
$$\text{ExplainedVariance} = 1 - \frac{\operatorname{Var}(y - \hat{y})}{\operatorname{Var}(y)}$$
where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.
Forward accepts
preds (float tensor): (N,) or (N, ...) (multioutput)
target (float tensor): (N,) or (N, ...) (multioutput)
In the case of multioutput, the variances are by default uniformly averaged over the additional dimensions. See the multioutput argument for changing this behavior.
- Parameters
Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
'raw_values' returns full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
- Raises
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
Example
>>> import torch
>>> from torchmetrics import ExplainedVariance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
compute()
Computes explained variance over state.
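A plain-Python sketch of the formula for single-output data, reproducing the first ExplainedVariance example. `explained_variance` is a hypothetical helper; it uses population variances, and because both variances share the same normalising factor, choosing 1/N or 1/(N-1) cancels out in the ratio.

```python
from typing import List


def explained_variance(preds: List[float], target: List[float]) -> float:
    """Hypothetical helper: 1 - Var(y - y_hat) / Var(y) for single-output data."""

    def variance(values: List[float]) -> float:
        mean = sum(values) / len(values)
        return sum((v - mean) ** 2 for v in values) / len(values)

    residuals = [t - p for p, t in zip(preds, target)]
    return 1 - variance(residuals) / variance(target)


print(round(explained_variance([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))  # 0.9572
```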
MeanAbsoluteError
class torchmetrics.MeanAbsoluteError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes mean absolute error (MAE):
$$\text{MAE} = \frac{1}{N}\sum_{i=1}^{N} |y_i - \hat{y}_i|$$
where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.
- Parameters
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import MeanAbsoluteError
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5000)
compute()
Computes mean absolute error over state.
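Module metrics such as this one accumulate state across update() calls and only divide in compute(). The plain-Python sketch below mimics that pattern, assuming the state is a running sum of absolute errors plus an observation count; `RunningMAE` is hypothetical and not the library class. Splitting the MeanAbsoluteError example into two batches gives the same result as a single call.

```python
from typing import List


class RunningMAE:
    """Hypothetical plain-Python analogue of an update()/compute() metric for MAE."""

    def __init__(self) -> None:
        self.sum_abs_error = 0.0  # accumulated |y - y_hat| over all batches
        self.total = 0            # number of observations seen so far

    def update(self, preds: List[float], target: List[float]) -> None:
        self.sum_abs_error += sum(abs(t - p) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.sum_abs_error / self.total


mae = RunningMAE()
mae.update([2.5, 0.0], [3.0, -0.5])  # first batch
mae.update([2.0, 8.0], [2.0, 7.0])   # second batch
print(mae.compute())  # 0.5
```

Keeping sums and counts, rather than per-batch means, is what makes the final value independent of how the data was split into batches.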
MeanSquaredError
class torchmetrics.MeanSquaredError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes mean squared error (MSE):
$$\text{MSE} = \frac{1}{N}\sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$
where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.
- Parameters
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import MeanSquaredError
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
compute()
Computes mean squared error over state.
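Worked through in plain Python for the MeanSquaredError example, the squared errors are 0.25, 0, 2.25 and 1, whose mean is 0.875. `mean_squared_error` below is a hypothetical helper mirroring the formula.

```python
from typing import List


def mean_squared_error(preds: List[float], target: List[float]) -> float:
    """Hypothetical helper: mean of (y - y_hat)**2."""
    return sum((t - p) ** 2 for p, t in zip(preds, target)) / len(target)


print(mean_squared_error([3.0, 5.0, 2.5, 7.0], [2.5, 5.0, 4.0, 8.0]))  # 0.875
```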
MeanSquaredLogError
class torchmetrics.MeanSquaredLogError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes mean squared logarithmic error (MSLE):
$$\text{MSLE} = \frac{1}{N}\sum_{i=1}^{N} \left(\log_e(1 + y_i) - \log_e(1 + \hat{y}_i)\right)^2$$
where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.
- Parameters
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import MeanSquaredLogError
>>> target = torch.tensor([2.5, 5, 4, 8])
>>> preds = torch.tensor([3, 5, 2.5, 7])
>>> mean_squared_log_error = MeanSquaredLogError()
>>> mean_squared_log_error(preds, target)
tensor(0.0397)
Note
Half precision is only supported on GPU for this metric.
compute()
Compute mean squared logarithmic error over state.
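The same computation in plain Python for the MeanSquaredLogError example. `math.log1p(x)` computes log(1 + x) accurately for small x, so inputs must be greater than -1; `mean_squared_log_error` is a hypothetical helper.

```python
import math
from typing import List


def mean_squared_log_error(preds: List[float], target: List[float]) -> float:
    """Hypothetical helper: mean of (log(1 + y) - log(1 + y_hat))**2."""
    return sum((math.log1p(t) - math.log1p(p)) ** 2 for p, t in zip(preds, target)) / len(target)


print(round(mean_squared_log_error([3, 5, 2.5, 7], [2.5, 5, 4, 8]), 4))  # 0.0397
```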
PearsonCorrcoef
class torchmetrics.PearsonCorrcoef(compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes the Pearson correlation coefficient:
$$P_{corr}(x, y) = \frac{\operatorname{cov}(x, y)}{\sigma_x \sigma_y}$$
where $y$ is a tensor of target values, and $x$ is a tensor of predictions.
Forward accepts
preds (float tensor): (N,)
target (float tensor): (N,)
- Parameters
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example
>>> import torch
>>> from torchmetrics import PearsonCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrcoef()
>>> pearson(preds, target)
tensor(0.9849)
compute()
Computes the Pearson correlation coefficient over state.
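A plain-Python sketch of the Pearson formula that reproduces the PearsonCorrcoef example. `pearson_corrcoef` is a hypothetical helper; since the 1/N factors in the covariance and both standard deviations cancel, plain sums of deviations are enough.

```python
import math
from typing import List


def pearson_corrcoef(preds: List[float], target: List[float]) -> float:
    """Hypothetical helper: cov(x, y) / (sigma_x * sigma_y)."""
    n = len(preds)
    mean_x, mean_y = sum(preds) / n, sum(target) / n
    # Unnormalised sums: the shared 1/n factors cancel in the ratio.
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(preds, target))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in preds))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in target))
    return cov / (std_x * std_y)


print(round(pearson_corrcoef([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))  # 0.9849
```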
PSNR
class torchmetrics.PSNR(data_range=None, base=10.0, reduction='elementwise_mean', dim=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes peak signal-to-noise ratio (PSNR):
$$\text{PSNR}(I, J) = 10 \cdot \log_{10}\left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right)$$
where $\text{MSE}$ denotes the mean-squared-error function.
- Parameters
data_range (Optional[float]) – The range of the data. If None, it is determined from the data (max - min). The data_range must be given when dim is not None.
A method to reduce metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none': no reduction will be applied
dim (Union[int, Tuple[int, ...], None]) – Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is None, meaning scores will be reduced across all dimensions and all batches.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
- Raises
ValueError – If dim is not None and data_range is not given.
Example
>>> import torch
>>> from torchmetrics import PSNR
>>> psnr = PSNR()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
tensor(2.5527)
Note
Half precision is only supported on GPU for this metric.
compute()
Compute peak signal-to-noise ratio over state.
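A plain-Python sketch of the PSNR formula, reproducing the PSNR example with the 2x2 tensors flattened to lists. It assumes that data_range plays the role of max(I), that when it is not given the range is taken as max - min of the target values, and that base sets the logarithm base; `psnr` is a hypothetical helper. Here data_range = 3 and MSE = 5, so PSNR = 10 * log10(9 / 5).

```python
import math
from typing import List, Optional


def psnr(preds: List[float], target: List[float], data_range: Optional[float] = None, base: float = 10.0) -> float:
    """Hypothetical helper: 10 * log_base(data_range**2 / MSE)."""
    if data_range is None:
        # Assumption: the range is taken from the target values (max - min).
        data_range = max(target) - min(target)
    mse = sum((t - p) ** 2 for p, t in zip(preds, target)) / len(target)
    return 10 * math.log(data_range ** 2 / mse, base)


# The 2x2 example above, flattened row by row.
print(round(psnr([0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]), 4))  # 2.5527
```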
R2Score
class torchmetrics.R2Score(num_outputs=1, adjusted=0, multioutput='uniform_average', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes the r2 score, also known as the coefficient of determination:
$$R^2 = 1 - \frac{SS_{res}}{SS_{tot}}$$
where $SS_{res} = \sum_i (y_i - \hat{y}_i)^2$ is the sum of residual squares, and $SS_{tot} = \sum_i (y_i - \bar{y})^2$ is the total sum of squares. It can also calculate the adjusted r2 score, given by
$$R^2_{adj} = 1 - \frac{(1 - R^2)(n - 1)}{n - k - 1}$$
where $n$ is the number of samples and the parameter $k$ (the number of independent regressors) should be provided as the adjusted argument.
Forward accepts
preds (float tensor): (N,) or (N, M) (multioutput)
target (float tensor): (N,) or (N, M) (multioutput)
In the case of multioutput, the scores are by default uniformly averaged over the additional dimensions. See the multioutput argument for changing this behavior.
- Parameters
num_outputs (int) – Number of outputs in multioutput setting (default is 1)
adjusted (int) – Number of independent regressors for calculating adjusted r2 score. Default 0 (standard r2 score).
Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
'raw_values' returns full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
- Raises
ValueError – If adjusted parameter is not an integer greater than or equal to 0.
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
Example
>>> import torch
>>> from torchmetrics import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
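A plain-Python sketch of the r2 and adjusted r2 formulas for single-output data, reproducing the first R2Score example. `r2_score` is a hypothetical helper; the adjusted variant is only defined when n - adjusted - 1 > 0.

```python
from typing import List


def r2_score(preds: List[float], target: List[float], adjusted: int = 0) -> float:
    """Hypothetical helper: 1 - SS_res / SS_tot, optionally adjusted for `adjusted` regressors."""
    n = len(target)
    mean_target = sum(target) / n
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, target))  # residual sum of squares
    ss_tot = sum((t - mean_target) ** 2 for t in target)       # total sum of squares
    r2 = 1 - ss_res / ss_tot
    if adjusted > 0:
        r2 = 1 - (1 - r2) * (n - 1) / (n - adjusted - 1)
    return r2


print(round(r2_score([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))              # 0.9486
print(round(r2_score([2.5, 0.0, 2, 8], [3, -0.5, 2, 7], adjusted=1), 4))  # 0.9229
```

Unlike explained variance, r2 also penalises a constant offset between predictions and targets, which is why the two scores differ slightly on the same data.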
SpearmanCorrcoef
class torchmetrics.SpearmanCorrcoef(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes Spearman's rank correlation coefficient:
$$r_s = \frac{\operatorname{cov}(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}$$
where $rg_x$ and $rg_y$ are the ranks associated with the variables $x$ and $y$. Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.
- Parameters
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Example
>>> import torch
>>> from torchmetrics import SpearmanCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> spearman = SpearmanCorrcoef()
>>> spearman(preds, target)
tensor(1.0000)
compute()
Computes Spearman's correlation coefficient.
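Because Spearman's coefficient is Pearson's coefficient on ranks, it can be sketched in plain Python as a ranking step followed by the Pearson formula. This sketch assigns tied values the average of their positions, a common convention; whether the library breaks ties the same way is not stated in this section. `rank`, `pearson` and `spearman` are hypothetical helpers.

```python
import math
from typing import List


def rank(values: List[float]) -> List[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend `end` over the run of values tied with the one at `start`.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        average_rank = (start + end) / 2 + 1  # mean of positions start+1 .. end+1
        for pos in range(start, end + 1):
            ranks[order[pos]] = average_rank
        start = end + 1
    return ranks


def pearson(x: List[float], y: List[float]) -> float:
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    std_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (std_x * std_y)


def spearman(preds: List[float], target: List[float]) -> float:
    return pearson(rank(preds), rank(target))


print(rank([2.5, 0.0, 2, 8]))                                   # [3.0, 1.0, 2.0, 4.0]
print(round(spearman([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))    # 1.0
```

In the SpearmanCorrcoef example, predictions and targets induce the same ordering, so the coefficient is exactly 1 even though the values differ.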
SSIM
class torchmetrics.SSIM(kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, compute_on_step=True, dist_sync_on_step=False, process_group=None)
Computes Structural Similarity Index Measure (SSIM).
- Parameters
kernel_size (Sequence[int]) – Size of the gaussian kernel (default: (11, 11))
sigma (Sequence[float]) – Standard deviation of the gaussian kernel (default: (1.5, 1.5))
A method to reduce metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none': no reduction will be applied
data_range (Optional[float]) – Range of the image. If None, it is determined from the image (max - min)
- Returns
Tensor with SSIM score
Example
>>> import torch
>>> from torchmetrics import SSIM
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ssim = SSIM()
>>> ssim(preds, target)
tensor(0.9219)
compute()
Computes SSIM over state.
Retrieval
Input details
For the purposes of retrieval metrics, inputs (indexes, predictions and targets) must have the same size (N stands for the batch size) and the following types:
| indexes shape | indexes dtype | preds shape | preds dtype | target shape | target dtype |
|---|---|---|---|---|---|
| (N, ...) | long | (N, ...) | float | (N, ...) | long or bool |
Note
All dimensions are flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ).
In Information Retrieval you have a query that is compared with a variable number of documents. For each pair (Q_i, D_j), a score is computed that measures the relevance of document D_j w.r.t. query Q_i. Documents are then sorted by score, and ideally relevant documents are scored higher. target contains the labels for the documents (relevant or not).
Since a query may be compared with a variable number of documents, we use indexes to keep track of which scores belong to the set of pairs (Q_i, D_j) having the same query Q_i.
Note
Retrieval metrics are only intended to be used globally. This means that the average of the metric over each batch can be quite different from the metric computed on the whole dataset. For this reason, we suggest computing the metric only when all the examples have been provided to the metric. When using PyTorch Lightning, we suggest using on_step=False and on_epoch=True in self.log, or placing the metric calculation in training_epoch_end, validation_epoch_end or test_epoch_end.
>>> import torch
>>> from torchmetrics import RetrievalMAP
>>> # functional version works on a single query at a time
>>> from torchmetrics.functional import retrieval_average_precision
>>> # the first query was compared with two documents, the second with three
>>> indexes = torch.tensor([0, 0, 1, 1, 1])
>>> preds = torch.tensor([0.8, -0.4, 1.0, 1.4, 0.0])
>>> target = torch.tensor([0, 1, 0, 1, 1])
>>> rmap = RetrievalMAP()  # or some other retrieval metric
>>> rmap(preds, target, indexes=indexes)
tensor(0.6667)
>>> # the previous instruction is roughly equivalent to
>>> res = []
>>> # iterate over the positions belonging to the first and the second query
>>> for query_positions in ([0, 1], [2, 3, 4]):
...     res.append(retrieval_average_precision(preds[query_positions], target[query_positions]))
>>> torch.stack(res).mean()
tensor(0.6667)
RetrievalMAP
class torchmetrics.RetrievalMAP(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes Mean Average Precision.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same dimension. indexes indicates to which query a prediction belongs. Predictions will first be grouped by indexes and then MAP will be computed as the mean of the Average Precisions over each query.
- Parameters
Specify what to do with queries that do not have at least a positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalMAP
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rmap = RetrievalMAP()
>>> rmap(preds, target, indexes=indexes)
tensor(0.7917)
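The group-then-average procedure can be sketched in plain Python: collect (score, relevant) pairs per query id, compute average precision for each query, and take the mean. For the RetrievalMAP example, query 0 scores 1.0 and query 1 scores (1/2 + 2/3) / 2, giving about 0.7917. `retrieval_map` and `average_precision` are hypothetical helpers; queries without relevant documents score 0.0 here, mirroring the 'neg' default.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def average_precision(pairs: List[Tuple[float, bool]]) -> float:
    """Mean of precision@rank over the ranks at which relevant documents appear."""
    ranked = sorted(pairs, key=lambda pair: pair[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, relevant) in enumerate(ranked, start=1):
        if relevant:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def retrieval_map(indexes: List[int], preds: List[float], target: List[bool]) -> float:
    groups: Dict[int, List[Tuple[float, bool]]] = defaultdict(list)
    for q, p, t in zip(indexes, preds, target):
        groups[q].append((p, t))  # group predictions by query id
    scores = [average_precision(pairs) for pairs in groups.values()]
    return sum(scores) / len(scores)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(round(retrieval_map(indexes, preds, target), 4))  # 0.7917
```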
RetrievalMRR
class torchmetrics.RetrievalMRR(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Computes Mean Reciprocal Rank.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same dimension. indexes indicates to which query a prediction belongs. Predictions will first be grouped by indexes and then MRR will be computed as the mean of the Reciprocal Rank over each query.
- Parameters
Specify what to do with queries that do not have at least a positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalMRR
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> mrr = RetrievalMRR()
>>> mrr(preds, target, indexes=indexes)
tensor(0.7500)
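The reciprocal rank of a query is 1 / (position of its first relevant document after sorting by score). In the RetrievalMRR example, query 0 has its relevant document first (1.0) and query 1 has it second (0.5), so the mean is 0.75. The plain-Python sketch below reproduces that; `retrieval_mrr` is a hypothetical helper, and queries without relevant documents score 0.0, mirroring the 'neg' default.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def reciprocal_rank(pairs: List[Tuple[float, bool]]) -> float:
    """1 / rank of the first relevant document, or 0.0 if there is none."""
    ranked = sorted(pairs, key=lambda pair: pair[0], reverse=True)
    for rank, (_, relevant) in enumerate(ranked, start=1):
        if relevant:
            return 1.0 / rank
    return 0.0


def retrieval_mrr(indexes: List[int], preds: List[float], target: List[bool]) -> float:
    groups: Dict[int, List[Tuple[float, bool]]] = defaultdict(list)
    for q, p, t in zip(indexes, preds, target):
        groups[q].append((p, t))  # group predictions by query id
    scores = [reciprocal_rank(pairs) for pairs in groups.values()]
    return sum(scores) / len(scores)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(retrieval_mrr(indexes, preds, target))  # 0.75
```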
RetrievalPrecision
class torchmetrics.RetrievalPrecision(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, k=None)
Computes Precision.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts:
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same dimension. indexes indicates to which query a prediction belongs. Predictions will first be grouped by indexes and then Precision will be computed as the mean of the Precision over each query.
- Parameters
Specify what to do with queries that do not have at least a positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
k (Optional[int]) – Consider only the top k elements for each query. default: None
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalPrecision(k=2)
>>> p2(preds, target, indexes=indexes)
tensor(0.5000)
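The same grouping applies to precision at k: rank each query's documents by score, keep the top k, and divide the number of relevant ones by k. Below is a minimal plain-Python sketch; `precision_at_k` is a hypothetical helper, queries with no positive target return 0.0 as with the default 'neg', and dividing by k even when a query has fewer than k documents is an assumption of this sketch.

```python
from collections import defaultdict


def precision_at_k(preds, target, indexes, k=None):
    """Hypothetical helper: mean over queries of (relevant items in the top k) / k."""
    groups = defaultdict(list)
    for score, rel, idx in zip(preds, target, indexes):
        groups[idx].append((score, rel))

    per_query = []
    for pairs in groups.values():
        if not any(rel for _, rel in pairs):
            per_query.append(0.0)  # mirrors the default 'neg' empty_target_action
            continue
        top_k = len(pairs) if k is None else k  # k=None: consider every document
        ranked = sorted(pairs, key=lambda p: p[0], reverse=True)[:top_k]
        per_query.append(sum(1 for _, rel in ranked if rel) / top_k)
    return sum(per_query) / len(per_query)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(precision_at_k(preds, target, indexes, k=2))  # 0.5
```

Each query has exactly one relevant document among its two highest-scored ones, hence 1/2 per query and 0.5 overall.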
RetrievalRecall
class torchmetrics.RetrievalRecall(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, k=None)
Computes Recall.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts:
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same shape. indexes indicates the query to which each prediction belongs. Predictions are first grouped by indexes, and Recall is then computed as the mean of the per-query Recall values.
Parameters
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
  'neg': those queries count as 0.0 (default)
  'pos': those queries count as 1.0
  'skip': skip those queries; if all queries are skipped, 0.0 is returned
  'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
k (Optional[int]) – consider only the top k elements for each query. default: None (all elements are considered)
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(k=2)
>>> r2(preds, target, indexes=indexes)
tensor(0.7500)
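Recall at k differs from precision at k only in the denominator: the relevant hits in the top k are divided by the total number of relevant documents of the query. Sketched in plain Python with a hypothetical `recall_at_k` helper (queries without a positive target return 0.0, matching the default 'neg'):

```python
from collections import defaultdict


def recall_at_k(preds, target, indexes, k=None):
    """Hypothetical helper: mean over queries of (relevant in top k) / (all relevant)."""
    groups = defaultdict(list)
    for score, rel, idx in zip(preds, target, indexes):
        groups[idx].append((score, rel))

    per_query = []
    for pairs in groups.values():
        total_relevant = sum(1 for _, rel in pairs if rel)
        if total_relevant == 0:
            per_query.append(0.0)  # mirrors the default 'neg' empty_target_action
            continue
        # Slicing with k=None keeps the whole ranking.
        ranked = sorted(pairs, key=lambda p: p[0], reverse=True)[:k]
        per_query.append(sum(1 for _, rel in ranked if rel) / total_relevant)
    return sum(per_query) / len(per_query)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(recall_at_k(preds, target, indexes, k=2))  # 0.75
```

Query 0 retrieves its only relevant document (1.0), query 1 retrieves one of its two (0.5), giving 0.75.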
RetrievalFallOut
class torchmetrics.RetrievalFallOut(empty_target_action='pos', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, k=None)
Computes Fall-out.
Works with binary target data. Accepts float predictions from a model output.
Forward accepts:
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same shape. indexes indicates the query to which each prediction belongs. Predictions are first grouped by indexes, and Fall-out is then computed as the mean of the per-query Fall-out values.
Parameters
empty_target_action (str) – Specify what to do with queries that do not have at least one negative target. Choose from:
  'neg': those queries count as 0.0
  'pos': those queries count as 1.0 (default)
  'skip': skip those queries; if all queries are skipped, 0.0 is returned
  'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
k (Optional[int]) – consider only the top k elements for each query. default: None (all elements are considered)
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalFallOut
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = RetrievalFallOut(k=2)
>>> fo(preds, target, indexes=indexes)
tensor(0.5000)
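Fall-out counts mistakes instead of hits: the number of non-relevant documents in the top k divided by all non-relevant documents of the query, so lower is better. In the hypothetical `fall_out_at_k` sketch below (plain Python, not the library's code), a query with no negative target returns 1.0, matching this class's default empty_target_action='pos'.

```python
from collections import defaultdict


def fall_out_at_k(preds, target, indexes, k=None):
    """Hypothetical helper: mean over queries of (non-relevant in top k) / (all non-relevant)."""
    groups = defaultdict(list)
    for score, rel, idx in zip(preds, target, indexes):
        groups[idx].append((score, rel))

    per_query = []
    for pairs in groups.values():
        total_negative = sum(1 for _, rel in pairs if not rel)
        if total_negative == 0:
            per_query.append(1.0)  # mirrors the default 'pos' empty_target_action
            continue
        # Slicing with k=None keeps the whole ranking.
        ranked = sorted(pairs, key=lambda p: p[0], reverse=True)[:k]
        per_query.append(sum(1 for _, rel in ranked if not rel) / total_negative)
    return sum(per_query) / len(per_query)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(fall_out_at_k(preds, target, indexes, k=2))  # 0.5
```

Both queries have two non-relevant documents and exactly one of them lands in the top 2, so each query scores 0.5.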
compute()
First concatenates the indexes, preds and target states, since they are stored as lists. It then computes the list of groups that keeps predictions about the same query together. Finally, for each group it computes _metric if the group has at least one negative target; otherwise it behaves as specified by self.empty_target_action.
Return type
Tensor
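The last step of compute(), applying empty_target_action to queries whose metric is undefined, can be isolated into a small helper. This is a hypothetical sketch, not the library's code: each query's score is passed in as a float, or as None when the query has no negative target.

```python
def reduce_over_queries(values, empty_target_action="pos"):
    """Hypothetical helper: average per-query scores, where None marks a query
    whose metric is undefined (for fall-out, a query with no negative target)."""
    if empty_target_action not in ("neg", "pos", "skip", "error"):
        raise ValueError(f"unknown empty_target_action: {empty_target_action!r}")
    kept = []
    for value in values:
        if value is not None:
            kept.append(value)
        elif empty_target_action == "neg":
            kept.append(0.0)
        elif empty_target_action == "pos":
            kept.append(1.0)
        elif empty_target_action == "error":
            raise ValueError("a query has no negative target")
        # 'skip': the query is left out of the average entirely.
    # If every query was skipped, fall back to 0.0 as the parameter docs describe.
    return sum(kept) / len(kept) if kept else 0.0


for action in ("neg", "pos", "skip"):
    print(action, reduce_over_queries([0.5, None], action))
```

With one defined query (0.5) and one undefined query, 'neg' gives 0.25, 'pos' gives 0.75 and 'skip' gives 0.5.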
RetrievalNormalizedDCG
class torchmetrics.RetrievalNormalizedDCG(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, k=None)
Computes Normalized Discounted Cumulative Gain.
Works with binary or positive integer target data. Accepts float predictions from a model output.
Forward accepts:
preds (float tensor): (N, ...)
target (long or bool tensor): (N, ...)
indexes (long tensor): (N, ...)
indexes, preds and target must have the same shape. indexes indicates the query to which each prediction belongs. Predictions are first grouped by indexes, and Normalized Discounted Cumulative Gain is then computed as the mean of the per-query Normalized Discounted Cumulative Gain values.
Parameters
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
  'neg': those queries count as 0.0 (default)
  'pos': those queries count as 1.0
  'skip': skip those queries; if all queries are skipped, 0.0 is returned
  'error': raise a ValueError
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None
k (Optional[int]) – consider only the top k elements for each query. default: None (all elements are considered)
Example
>>> from torch import tensor
>>> from torchmetrics import RetrievalNormalizedDCG
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> ndcg = RetrievalNormalizedDCG()
>>> ndcg(preds, target, indexes=indexes)
tensor(0.8467)
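The nDCG value in the example can be reproduced by hand. This plain-Python sketch uses the linear gain rel / log2(rank + 1) with 1-based ranks (an assumption of the sketch; the 2^rel - 1 variant gives the same numbers for binary targets) and divides each query's DCG by the DCG of its ideally sorted relevances. `ndcg_at_k` and `dcg` are hypothetical helpers.

```python
import math
from collections import defaultdict


def dcg(relevances):
    """Linear-gain DCG: sum of rel_i / log2(i + 1) for 1-based rank i."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances, start=1))


def ndcg_at_k(preds, target, indexes, k=None):
    """Hypothetical helper: mean over queries of DCG@k / ideal DCG@k."""
    groups = defaultdict(list)
    for score, rel, idx in zip(preds, target, indexes):
        groups[idx].append((score, float(rel)))

    per_query = []
    for pairs in groups.values():
        # Ideal ranking: relevance grades sorted from best to worst.
        ideal = sorted((rel for _, rel in pairs), reverse=True)[:k]
        if not any(ideal):
            per_query.append(0.0)  # mirrors the default 'neg' empty_target_action
            continue
        # Actual ranking: relevance grades ordered by descending prediction score.
        ranked = [rel for _, rel in sorted(pairs, key=lambda p: p[0], reverse=True)][:k]
        per_query.append(dcg(ranked) / dcg(ideal))
    return sum(per_query) / len(per_query)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(round(ndcg_at_k(preds, target, indexes), 4))  # 0.8467
```

Query 0 is ranked perfectly (1.0). In query 1 the relevant documents sit at ranks 2 and 3, giving (1/log2 3 + 1/log2 4) / (1 + 1/log2 3) ≈ 0.6934, and the mean of the two is 0.8467.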
Wrappers
Modular wrapper metrics are not metrics in themselves, but instead take a metric and alter the internal logic of the base metric.
class torchmetrics.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)
Turns a metric into a bootstrapped metric, automating the process of getting confidence intervals for metric values. This wrapper class keeps multiple copies of the same base metric in memory, and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.
Parameters
base_metric (Metric) – base metric class to wrap
num_bootstraps (int) – number of copies to make of the base metric for bootstrapping
mean (bool) – if True, return the mean of the bootstraps
std (bool) – if True, return the standard deviation of the bootstraps
quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher
raw (bool) – if True, return all bootstrapped values
sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample will be included in the bootstrap is given by $n \sim \mathrm{Poisson}(\lambda = 1)$, which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.
compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.
dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
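The two sampling strategies can be sketched in plain Python under stated assumptions: `bootstrap_indices` and `poisson1` are hypothetical names, the Poisson(1) draw uses Knuth's multiplication method on the standard library's random module rather than torch, and the returned indices would select rows along the first dimension of every input tensor.

```python
import math
import random


def poisson1(rng):
    """Draw one sample from Poisson(1) with Knuth's multiplication method."""
    threshold = math.exp(-1.0)
    count, prod = 0, rng.random()
    while prod > threshold:
        count += 1
        prod *= rng.random()
    return count


def bootstrap_indices(size, strategy, rng):
    """Hypothetical helper: row indices for one bootstrap resample of a batch."""
    if strategy == "poisson":
        # Each row is repeated Poisson(1) times, so the resample length varies.
        return [i for i in range(size) for _ in range(poisson1(rng))]
    if strategy == "multinomial":
        # Exactly `size` rows drawn uniformly with replacement.
        return [rng.randrange(size) for _ in range(size)]
    raise ValueError(f"unknown sampling_strategy: {strategy!r}")


rng = random.Random(0)
batch = ["a", "b", "c", "d", "e"]
for strategy in ("poisson", "multinomial"):
    rows = bootstrap_indices(len(batch), strategy, rng)
    print(strategy, [batch[i] for i in rows])
```

With 'poisson' every row is drawn independently, so the resample size fluctuates around the batch size; 'multinomial' keeps it fixed.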
Example
>>> from pprint import pprint
>>> import torch
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
compute()
Computes the bootstrapped metric values. Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw, depending on how the class was initialized.
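To show how that dict is assembled, here is a minimal sketch with plain floats instead of tensors. `summarize_bootstraps` and `linear_quantile` are hypothetical names; the sketch assumes one metric value per bootstrap copy, uses the sample (n - 1) standard deviation, which is also what torch.Tensor.std computes by default, and uses linear interpolation for the quantile, the default of torch.quantile.

```python
import math
import statistics


def linear_quantile(values, q):
    """q-quantile with linear interpolation between the two nearest order statistics."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lower = math.floor(pos)
    upper = min(lower + 1, len(ordered) - 1)
    return ordered[lower] + (pos - lower) * (ordered[upper] - ordered[lower])


def summarize_bootstraps(values, mean=True, std=True, quantile=None, raw=False):
    """Hypothetical helper: build the output dict from one metric value per bootstrap copy."""
    out = {}
    if mean:
        out["mean"] = statistics.fmean(values)
    if std:
        out["std"] = statistics.stdev(values)  # sample (n - 1) standard deviation
    if quantile is not None:
        out["quantile"] = linear_quantile(values, quantile)
    if raw:
        out["raw"] = list(values)
    return out


print(summarize_bootstraps([0.2, 0.4, 0.6], quantile=0.5, raw=True))
```

Only the keys whose flags are enabled appear, which mirrors the "depending on how the class was initialized" behaviour described above.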