Skip to content

Commit a731269

Browse files
authored
Prune deprecated metrics for 1.3 (#6161)
* prune deprecated metrics for 1.3 * isort / yapf
1 parent 1d9c553 commit a731269

File tree

8 files changed

+20
-356
lines changed

8 files changed

+20
-356
lines changed

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
13+
1214

1315
### Changed
1416

@@ -24,6 +26,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2426
- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))
2527

2628

29+
- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
30+
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
31+
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
32+
33+
2734
### Fixed
2835

2936
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
@@ -93,7 +100,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
93100
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
94101
- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
95102
[#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))
96-
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
97103

98104
### Changed
99105

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
multiclass_auroc,
2222
stat_scores_multiple_classes,
2323
to_categorical,
24-
to_onehot,
2524
)
2625
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
2726
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401

pytorch_lightning/metrics/functional/classification.py

Lines changed: 10 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -18,70 +18,11 @@
1818

1919
from pytorch_lightning.metrics.functional.auc import auc as __auc
2020
from pytorch_lightning.metrics.functional.auroc import auroc as __auroc
21-
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
2221
from pytorch_lightning.metrics.functional.iou import iou as __iou
23-
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
24-
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve as __prc
25-
from pytorch_lightning.metrics.functional.roc import roc as __roc
26-
from pytorch_lightning.metrics.utils import class_reduce
27-
from pytorch_lightning.metrics.utils import get_num_classes as __gnc
28-
from pytorch_lightning.metrics.utils import reduce
29-
from pytorch_lightning.metrics.utils import to_categorical as __tc
30-
from pytorch_lightning.metrics.utils import to_onehot as __to
22+
from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical
3123
from pytorch_lightning.utilities import rank_zero_warn
3224

3325

34-
def to_onehot(
35-
tensor: torch.Tensor,
36-
num_classes: Optional[int] = None,
37-
) -> torch.Tensor:
38-
"""
39-
Converts a dense label tensor to one-hot format
40-
41-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
42-
"""
43-
rank_zero_warn(
44-
"This `to_onehot` was deprecated in v1.1.0 in favor of"
45-
" `from pytorch_lightning.metrics.utils import to_onehot`."
46-
" It will be removed in v1.3.0", DeprecationWarning
47-
)
48-
return __to(tensor, num_classes)
49-
50-
51-
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
52-
"""
53-
Converts a tensor of probabilities to a dense label tensor
54-
55-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`
56-
57-
"""
58-
rank_zero_warn(
59-
"This `to_categorical` was deprecated in v1.1.0 in favor of"
60-
" `from pytorch_lightning.metrics.utils import to_categorical`."
61-
" It will be removed in v1.3.0", DeprecationWarning
62-
)
63-
return __tc(tensor)
64-
65-
66-
def get_num_classes(
67-
pred: torch.Tensor,
68-
target: torch.Tensor,
69-
num_classes: Optional[int] = None,
70-
) -> int:
71-
"""
72-
Calculates the number of classes for a given prediction and target tensor.
73-
74-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`
75-
76-
"""
77-
rank_zero_warn(
78-
"This `get_num_classes` was deprecated in v1.1.0 in favor of"
79-
" `from pytorch_lightning.metrics.utils import get_num_classes`."
80-
" It will be removed in v1.3.0", DeprecationWarning
81-
)
82-
return __gnc(pred, target, num_classes)
83-
84-
8526
def stat_scores(
8627
pred: torch.Tensor,
8728
target: torch.Tensor,
@@ -122,6 +63,7 @@ def stat_scores(
12263
return tp, fp, tn, fn, sup
12364

12465

66+
# todo: remove in 1.4
12567
def stat_scores_multiple_classes(
12668
pred: torch.Tensor,
12769
target: torch.Tensor,
@@ -210,6 +152,7 @@ def _confmat_normalize(cm):
210152
return cm
211153

212154

155+
# todo: remove in 1.4
213156
def precision_recall(
214157
pred: torch.Tensor,
215158
target: torch.Tensor,
@@ -268,6 +211,7 @@ def precision_recall(
268211
return precision, recall
269212

270213

214+
# todo: remove in 1.4
271215
def precision(
272216
pred: torch.Tensor,
273217
target: torch.Tensor,
@@ -311,6 +255,7 @@ def precision(
311255
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0]
312256

313257

258+
# todo: remove in 1.4
314259
def recall(
315260
pred: torch.Tensor,
316261
target: torch.Tensor,
@@ -353,128 +298,7 @@ def recall(
353298
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1]
354299

355300

356-
# todo: remove in 1.3
357-
def roc(
358-
pred: torch.Tensor,
359-
target: torch.Tensor,
360-
sample_weight: Optional[Sequence] = None,
361-
pos_label: int = 1.,
362-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
363-
"""
364-
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
365-
366-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
367-
"""
368-
rank_zero_warn(
369-
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
370-
" `from pytorch_lightning.metrics.functional.roc import roc`."
371-
" It will be removed in v1.3.0", DeprecationWarning
372-
)
373-
return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
374-
375-
376-
# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
377-
def _roc(
378-
pred: torch.Tensor,
379-
target: torch.Tensor,
380-
sample_weight: Optional[Sequence] = None,
381-
pos_label: int = 1.,
382-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
383-
"""
384-
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
385-
386-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
387-
388-
Example:
389-
390-
>>> x = torch.tensor([0, 1, 2, 3])
391-
>>> y = torch.tensor([0, 1, 1, 1])
392-
>>> fpr, tpr, thresholds = _roc(x, y)
393-
>>> fpr
394-
tensor([0., 0., 0., 0., 1.])
395-
>>> tpr
396-
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
397-
>>> thresholds
398-
tensor([4, 3, 2, 1, 0])
399-
400-
"""
401-
rank_zero_warn(
402-
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
403-
" `from pytorch_lightning.metrics.functional.roc import roc`."
404-
" It will be removed in v1.3.0", DeprecationWarning
405-
)
406-
fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label)
407-
408-
# Add an extra threshold position
409-
# to make sure that the curve starts at (0, 0)
410-
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
411-
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
412-
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
413-
414-
if fps[-1] <= 0:
415-
raise ValueError("No negative samples in targets, false positive value should be meaningless")
416-
417-
fpr = fps / fps[-1]
418-
419-
if tps[-1] <= 0:
420-
raise ValueError("No positive samples in targets, true positive value should be meaningless")
421-
422-
tpr = tps / tps[-1]
423-
424-
return fpr, tpr, thresholds
425-
426-
427-
# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
428-
def multiclass_roc(
429-
pred: torch.Tensor,
430-
target: torch.Tensor,
431-
sample_weight: Optional[Sequence] = None,
432-
num_classes: Optional[int] = None,
433-
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
434-
"""
435-
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
436-
437-
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
438-
439-
Args:
440-
pred: estimated probabilities
441-
target: ground-truth labels
442-
sample_weight: sample weights
443-
num_classes: number of classes (default: None, computes automatically from data)
444-
445-
Return:
446-
returns roc for each class.
447-
Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
448-
449-
Example:
450-
451-
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
452-
... [0.05, 0.85, 0.05, 0.05],
453-
... [0.05, 0.05, 0.85, 0.05],
454-
... [0.05, 0.05, 0.05, 0.85]])
455-
>>> target = torch.tensor([0, 1, 3, 2])
456-
>>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
457-
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
458-
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
459-
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
460-
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
461-
"""
462-
rank_zero_warn(
463-
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
464-
" `from pytorch_lightning.metrics.functional.roc import roc`."
465-
" It will be removed in v1.3.0", DeprecationWarning
466-
)
467-
num_classes = get_num_classes(pred, target, num_classes)
468-
469-
class_roc_vals = []
470-
for c in range(num_classes):
471-
pred_c = pred[:, c]
472-
473-
class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))
474-
475-
return tuple(class_roc_vals)
476-
477-
301+
# todo: remove in 1.4
478302
def auc(
479303
x: torch.Tensor,
480304
y: torch.Tensor,
@@ -508,6 +332,7 @@ def auc(
508332
return __auc(x, y)
509333

510334

335+
# todo: remove in 1.4
511336
def auc_decorator() -> Callable:
512337
rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning)
513338

@@ -524,6 +349,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
524349
return wrapper
525350

526351

352+
# todo: remove in 1.4
527353
def multiclass_auc_decorator() -> Callable:
528354
rank_zero_warn(
529355
"This `multiclass_auc_decorator` was deprecated in v1.2.0."
@@ -546,6 +372,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
546372
return wrapper
547373

548374

375+
# todo: remove in 1.4
549376
def auroc(
550377
pred: torch.Tensor,
551378
target: torch.Tensor,
@@ -588,6 +415,7 @@ def auroc(
588415
)
589416

590417

418+
# todo: remove in 1.4
591419
def multiclass_auroc(
592420
pred: torch.Tensor,
593421
target: torch.Tensor,
@@ -767,68 +595,3 @@ def iou(
767595
num_classes=num_classes,
768596
reduction=reduction
769597
)
770-
771-
772-
# todo: remove in 1.3
773-
def precision_recall_curve(
774-
pred: torch.Tensor,
775-
target: torch.Tensor,
776-
sample_weight: Optional[Sequence] = None,
777-
pos_label: int = 1.,
778-
):
779-
"""
780-
Computes precision-recall pairs for different thresholds.
781-
782-
.. warning :: Deprecated in favor of
783-
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
784-
"""
785-
rank_zero_warn(
786-
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
787-
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
788-
" It will be removed in v1.3.0", DeprecationWarning
789-
)
790-
return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
791-
792-
793-
# todo: remove in 1.3
794-
def multiclass_precision_recall_curve(
795-
pred: torch.Tensor,
796-
target: torch.Tensor,
797-
sample_weight: Optional[Sequence] = None,
798-
num_classes: Optional[int] = None,
799-
):
800-
"""
801-
Computes precision-recall pairs for different thresholds given a multiclass scores.
802-
803-
.. warning :: Deprecated in favor of
804-
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
805-
"""
806-
rank_zero_warn(
807-
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
808-
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
809-
" It will be removed in v1.3.0", DeprecationWarning
810-
)
811-
if num_classes is None:
812-
num_classes = get_num_classes(pred, target, num_classes)
813-
return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)
814-
815-
816-
# todo: remove in 1.3
817-
def average_precision(
818-
pred: torch.Tensor,
819-
target: torch.Tensor,
820-
sample_weight: Optional[Sequence] = None,
821-
pos_label: int = 1.,
822-
):
823-
"""
824-
Compute average precision from prediction scores.
825-
826-
.. warning :: Deprecated in favor of
827-
:func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
828-
"""
829-
rank_zero_warn(
830-
"This `average_precision` was deprecated in v1.1.0 in favor of"
831-
" `pytorch_lightning.metrics.functional.average_precision import average_precision`."
832-
" It will be removed in v1.3.0", DeprecationWarning
833-
)
834-
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)

pytorch_lightning/metrics/functional/iou.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
import torch
1717

1818
from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
19-
from pytorch_lightning.metrics.functional.reduction import reduce
20-
from pytorch_lightning.metrics.utils import get_num_classes
19+
from pytorch_lightning.metrics.utils import get_num_classes, reduce
2120

2221

2322
def _iou_from_confmat(

0 commit comments

Comments
 (0)