Skip to content

Commit 3a56a60

Browse files
authored
Prune metrics: other classification 7/n (#6584)
* confusion_matrix * iou * f_beta * hamming_distance * stat_scores * tests * flake8 * chlog
1 parent 3b72bcc commit 3a56a60

20 files changed

+155
-2421
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7878

7979
[#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573),
8080

81+
[#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584),
82+
8183
)
8284

8385

pytorch_lightning/metrics/classification/confusion_matrix.py

Lines changed: 7 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -13,64 +13,14 @@
1313
# limitations under the License.
1414
from typing import Any, Optional
1515

16-
import torch
17-
from torchmetrics import Metric
16+
from torchmetrics import ConfusionMatrix as _ConfusionMatrix
1817

19-
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update
18+
from pytorch_lightning.utilities.deprecation import deprecated
2019

2120

22-
class ConfusionMatrix(Metric):
23-
"""
24-
Computes the `confusion matrix
25-
<https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix>`_. Works with binary,
26-
multiclass, and multilabel data. Accepts probabilities from a model output or
27-
integer class values in prediction. Works with multi-dimensional preds and
28-
target.
29-
30-
Note:
31-
This metric produces a multi-dimensional output, so it can not be directly logged.
32-
33-
Forward accepts
34-
35-
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
36-
- ``target`` (long tensor): ``(N, ...)``
37-
38-
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
39-
to convert into integer labels. This is the case for binary and multi-label probabilities.
40-
41-
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
42-
43-
Args:
44-
num_classes: Number of classes in the dataset.
45-
normalize: Normalization mode for confusion matrix. Choose from
46-
47-
- ``None`` or ``'none'``: no normalization (default)
48-
- ``'true'``: normalization over the targets (most commonly used)
49-
- ``'pred'``: normalization over the predictions
50-
- ``'all'``: normalization over the whole matrix
51-
52-
threshold:
53-
Threshold value for binary or multi-label probabilites. default: 0.5
54-
compute_on_step:
55-
Forward only calls ``update()`` and return None if this is set to False. default: True
56-
dist_sync_on_step:
57-
Synchronize metric state across processes at each ``forward()``
58-
before returning the value at the step. default: False
59-
process_group:
60-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
61-
62-
Example:
63-
64-
>>> from pytorch_lightning.metrics import ConfusionMatrix
65-
>>> target = torch.tensor([1, 1, 0, 0])
66-
>>> preds = torch.tensor([0, 1, 0, 0])
67-
>>> confmat = ConfusionMatrix(num_classes=2)
68-
>>> confmat(preds, target)
69-
tensor([[2., 0.],
70-
[1., 1.]])
71-
72-
"""
21+
class ConfusionMatrix(_ConfusionMatrix):
7322

23+
@deprecated(target=_ConfusionMatrix, ver_deprecate="1.3.0", ver_remove="1.5.0")
7424
def __init__(
7525
self,
7626
num_classes: int,
@@ -80,35 +30,9 @@ def __init__(
8030
dist_sync_on_step: bool = False,
8131
process_group: Optional[Any] = None,
8232
):
83-
84-
super().__init__(
85-
compute_on_step=compute_on_step,
86-
dist_sync_on_step=dist_sync_on_step,
87-
process_group=process_group,
88-
)
89-
self.num_classes = num_classes
90-
self.normalize = normalize
91-
self.threshold = threshold
92-
93-
allowed_normalize = ('true', 'pred', 'all', 'none', None)
94-
assert self.normalize in allowed_normalize, \
95-
f"Argument average needs to one of the following: {allowed_normalize}"
96-
97-
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")
98-
99-
def update(self, preds: torch.Tensor, target: torch.Tensor):
100-
"""
101-
Update state with predictions and targets.
102-
103-
Args:
104-
preds: Predictions from model
105-
target: Ground truth values
10633
"""
107-
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
108-
self.confmat += confmat
34+
This implementation refers to :class:`~torchmetrics.ConfusionMatrix`.
10935
110-
def compute(self) -> torch.Tensor:
111-
"""
112-
Computes confusion matrix
36+
.. deprecated::
37+
Use :class:`~torchmetrics.ConfusionMatrix`. Will be removed in v1.5.0.
11338
"""
114-
return _confusion_matrix_compute(self.confmat, self.normalize)

pytorch_lightning/metrics/classification/f_beta.py

Lines changed: 15 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -13,72 +13,15 @@
1313
# limitations under the License.
1414
from typing import Any, Optional
1515

16-
import torch
17-
from torchmetrics import Metric
16+
from torchmetrics import F1 as _F1
17+
from torchmetrics import FBeta as _FBeta
1818

19-
from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update
20-
from pytorch_lightning.utilities import rank_zero_warn
19+
from pytorch_lightning.utilities.deprecation import deprecated
2120

2221

23-
class FBeta(Metric):
24-
r"""
25-
Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_, specifically:
26-
27-
.. math::
28-
F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
29-
{(\beta^2 * \text{precision}) + \text{recall}}
30-
31-
Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
32-
Accepts probabilities from a model output or integer class values in prediction.
33-
Works with multi-dimensional preds and target.
34-
35-
Forward accepts
36-
37-
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
38-
- ``target`` (long tensor): ``(N, ...)``
39-
40-
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
41-
to convert into integer labels. This is the case for binary and multi-label probabilities.
42-
43-
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
44-
45-
Args:
46-
num_classes: Number of classes in the dataset.
47-
beta: Beta coefficient in the F measure.
48-
threshold:
49-
Threshold value for binary or multi-label probabilities. default: 0.5
50-
51-
average:
52-
- ``'micro'`` computes metric globally
53-
- ``'macro'`` computes metric for each class and uniformly averages them
54-
- ``'weighted'`` computes metric for each class and does a weighted-average,
55-
where each class is weighted by their support (accounts for class imbalance)
56-
- ``'none'`` or ``None`` computes and returns the metric per class
57-
58-
multilabel: If predictions are from multilabel classification.
59-
compute_on_step:
60-
Forward only calls ``update()`` and return None if this is set to False. default: True
61-
dist_sync_on_step:
62-
Synchronize metric state across processes at each ``forward()``
63-
before returning the value at the step. default: False
64-
process_group:
65-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
66-
67-
Raises:
68-
ValueError:
69-
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``.
70-
71-
Example:
72-
73-
>>> from pytorch_lightning.metrics import FBeta
74-
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
75-
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
76-
>>> f_beta = FBeta(num_classes=3, beta=0.5)
77-
>>> f_beta(preds, target)
78-
tensor(0.3333)
79-
80-
"""
22+
class FBeta(_FBeta):
8123

24+
@deprecated(target=_FBeta, ver_deprecate="1.3.0", ver_remove="1.5.0")
8225
def __init__(
8326
self,
8427
num_classes: int,
@@ -90,103 +33,17 @@ def __init__(
9033
dist_sync_on_step: bool = False,
9134
process_group: Optional[Any] = None,
9235
):
93-
super().__init__(
94-
compute_on_step=compute_on_step,
95-
dist_sync_on_step=dist_sync_on_step,
96-
process_group=process_group,
97-
)
98-
99-
self.num_classes = num_classes
100-
self.beta = beta
101-
self.threshold = threshold
102-
self.average = average
103-
self.multilabel = multilabel
104-
105-
allowed_average = ("micro", "macro", "weighted", "none", None)
106-
if self.average not in allowed_average:
107-
raise ValueError(
108-
'Argument `average` expected to be one of the following:'
109-
f' {allowed_average} but got {self.average}'
110-
)
111-
112-
self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
113-
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
114-
self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
115-
116-
def update(self, preds: torch.Tensor, target: torch.Tensor):
117-
"""
118-
Update state with predictions and targets.
119-
120-
Args:
121-
preds: Predictions from model
122-
target: Ground truth values
12336
"""
124-
true_positives, predicted_positives, actual_positives = _fbeta_update(
125-
preds, target, self.num_classes, self.threshold, self.multilabel
126-
)
127-
128-
self.true_positives += true_positives
129-
self.predicted_positives += predicted_positives
130-
self.actual_positives += actual_positives
37+
This implementation refers to :class:`~torchmetrics.FBeta`.
13138
132-
def compute(self) -> torch.Tensor:
39+
.. deprecated::
40+
Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0.
13341
"""
134-
Computes fbeta over state.
135-
"""
136-
return _fbeta_compute(
137-
self.true_positives, self.predicted_positives, self.actual_positives, self.beta, self.average
138-
)
139-
140-
141-
class F1(FBeta):
142-
"""
143-
Computes F1 metric. F1 metrics correspond to a harmonic mean of the
144-
precision and recall scores.
145-
146-
Works with binary, multiclass, and multilabel data.
147-
Accepts logits from a model output or integer class values in prediction.
148-
Works with multi-dimensional preds and target.
14942

150-
Forward accepts
15143

152-
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
153-
- ``target`` (long tensor): ``(N, ...)``
154-
155-
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
156-
This is the case for binary and multi-label logits.
157-
158-
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
159-
160-
Args:
161-
num_classes: Number of classes in the dataset.
162-
threshold:
163-
Threshold value for binary or multi-label logits. default: 0.5
164-
165-
average:
166-
- ``'micro'`` computes metric globally
167-
- ``'macro'`` computes metric for each class and uniformly averages them
168-
- ``'weighted'`` computes metric for each class and does a weighted-average,
169-
where each class is weighted by their support (accounts for class imbalance)
170-
- ``'none'`` or ``None`` computes and returns the metric per class
171-
172-
multilabel: If predictions are from multilabel classification.
173-
compute_on_step:
174-
Forward only calls ``update()`` and returns None if this is set to False. default: True
175-
dist_sync_on_step:
176-
Synchronize metric state across processes at each ``forward()``
177-
before returning the value at the step. default: False
178-
process_group:
179-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
180-
181-
Example:
182-
>>> from pytorch_lightning.metrics import F1
183-
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
184-
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
185-
>>> f1 = F1(num_classes=3)
186-
>>> f1(preds, target)
187-
tensor(0.3333)
188-
"""
44+
class F1(_F1):
18945

46+
@deprecated(target=_F1, ver_deprecate="1.3.0", ver_remove="1.5.0")
19047
def __init__(
19148
self,
19249
num_classes: int,
@@ -197,16 +54,9 @@ def __init__(
19754
dist_sync_on_step: bool = False,
19855
process_group: Optional[Any] = None,
19956
):
200-
if multilabel is not False:
201-
rank_zero_warn(f'The `multilabel={multilabel}` parameter is unused and will not have any effect.')
57+
"""
58+
This implementation refers to :class:`~torchmetrics.F1`.
20259
203-
super().__init__(
204-
num_classes=num_classes,
205-
beta=1.0,
206-
threshold=threshold,
207-
average=average,
208-
multilabel=multilabel,
209-
compute_on_step=compute_on_step,
210-
dist_sync_on_step=dist_sync_on_step,
211-
process_group=process_group,
212-
)
60+
.. deprecated::
61+
Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
62+
"""

0 commit comments

Comments
 (0)