18
18
19
19
from pytorch_lightning .metrics .functional .auc import auc as __auc
20
20
from pytorch_lightning .metrics .functional .auroc import auroc as __auroc
21
- from pytorch_lightning .metrics .functional .average_precision import average_precision as __ap
22
21
from pytorch_lightning .metrics .functional .iou import iou as __iou
23
- from pytorch_lightning .metrics .functional .precision_recall_curve import _binary_clf_curve
24
- from pytorch_lightning .metrics .functional .precision_recall_curve import precision_recall_curve as __prc
25
- from pytorch_lightning .metrics .functional .roc import roc as __roc
26
- from pytorch_lightning .metrics .utils import class_reduce
27
- from pytorch_lightning .metrics .utils import get_num_classes as __gnc
28
- from pytorch_lightning .metrics .utils import reduce
29
- from pytorch_lightning .metrics .utils import to_categorical as __tc
30
- from pytorch_lightning .metrics .utils import to_onehot as __to
22
+ from pytorch_lightning .metrics .utils import class_reduce , get_num_classes , reduce , to_categorical
31
23
from pytorch_lightning .utilities import rank_zero_warn
32
24
33
25
34
- def to_onehot (
35
- tensor : torch .Tensor ,
36
- num_classes : Optional [int ] = None ,
37
- ) -> torch .Tensor :
38
- """
39
- Converts a dense label tensor to one-hot format
40
-
41
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
42
- """
43
- rank_zero_warn (
44
- "This `to_onehot` was deprecated in v1.1.0 in favor of"
45
- " `from pytorch_lightning.metrics.utils import to_onehot`."
46
- " It will be removed in v1.3.0" , DeprecationWarning
47
- )
48
- return __to (tensor , num_classes )
49
-
50
-
51
- def to_categorical (tensor : torch .Tensor , argmax_dim : int = 1 ) -> torch .Tensor :
52
- """
53
- Converts a tensor of probabilities to a dense label tensor
54
-
55
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`
56
-
57
- """
58
- rank_zero_warn (
59
- "This `to_categorical` was deprecated in v1.1.0 in favor of"
60
- " `from pytorch_lightning.metrics.utils import to_categorical`."
61
- " It will be removed in v1.3.0" , DeprecationWarning
62
- )
63
- return __tc (tensor )
64
-
65
-
66
- def get_num_classes (
67
- pred : torch .Tensor ,
68
- target : torch .Tensor ,
69
- num_classes : Optional [int ] = None ,
70
- ) -> int :
71
- """
72
- Calculates the number of classes for a given prediction and target tensor.
73
-
74
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`
75
-
76
- """
77
- rank_zero_warn (
78
- "This `get_num_classes` was deprecated in v1.1.0 in favor of"
79
- " `from pytorch_lightning.metrics.utils import get_num_classes`."
80
- " It will be removed in v1.3.0" , DeprecationWarning
81
- )
82
- return __gnc (pred , target , num_classes )
83
-
84
-
85
26
def stat_scores (
86
27
pred : torch .Tensor ,
87
28
target : torch .Tensor ,
@@ -122,6 +63,7 @@ def stat_scores(
122
63
return tp , fp , tn , fn , sup
123
64
124
65
66
+ # todo: remove in 1.4
125
67
def stat_scores_multiple_classes (
126
68
pred : torch .Tensor ,
127
69
target : torch .Tensor ,
@@ -210,6 +152,7 @@ def _confmat_normalize(cm):
210
152
return cm
211
153
212
154
155
+ # todo: remove in 1.4
213
156
def precision_recall (
214
157
pred : torch .Tensor ,
215
158
target : torch .Tensor ,
@@ -268,6 +211,7 @@ def precision_recall(
268
211
return precision , recall
269
212
270
213
214
+ # todo: remove in 1.4
271
215
def precision (
272
216
pred : torch .Tensor ,
273
217
target : torch .Tensor ,
@@ -311,6 +255,7 @@ def precision(
311
255
return precision_recall (pred = pred , target = target , num_classes = num_classes , class_reduction = class_reduction )[0 ]
312
256
313
257
258
+ # todo: remove in 1.4
314
259
def recall (
315
260
pred : torch .Tensor ,
316
261
target : torch .Tensor ,
@@ -353,128 +298,7 @@ def recall(
353
298
return precision_recall (pred = pred , target = target , num_classes = num_classes , class_reduction = class_reduction )[1 ]
354
299
355
300
356
- # todo: remove in 1.3
357
- def roc (
358
- pred : torch .Tensor ,
359
- target : torch .Tensor ,
360
- sample_weight : Optional [Sequence ] = None ,
361
- pos_label : int = 1. ,
362
- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
363
- """
364
- Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
365
-
366
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
367
- """
368
- rank_zero_warn (
369
- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
370
- " `from pytorch_lightning.metrics.functional.roc import roc`."
371
- " It will be removed in v1.3.0" , DeprecationWarning
372
- )
373
- return __roc (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
374
-
375
-
376
- # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
377
- def _roc (
378
- pred : torch .Tensor ,
379
- target : torch .Tensor ,
380
- sample_weight : Optional [Sequence ] = None ,
381
- pos_label : int = 1. ,
382
- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
383
- """
384
- Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
385
-
386
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
387
-
388
- Example:
389
-
390
- >>> x = torch.tensor([0, 1, 2, 3])
391
- >>> y = torch.tensor([0, 1, 1, 1])
392
- >>> fpr, tpr, thresholds = _roc(x, y)
393
- >>> fpr
394
- tensor([0., 0., 0., 0., 1.])
395
- >>> tpr
396
- tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
397
- >>> thresholds
398
- tensor([4, 3, 2, 1, 0])
399
-
400
- """
401
- rank_zero_warn (
402
- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
403
- " `from pytorch_lightning.metrics.functional.roc import roc`."
404
- " It will be removed in v1.3.0" , DeprecationWarning
405
- )
406
- fps , tps , thresholds = _binary_clf_curve (pred , target , sample_weights = sample_weight , pos_label = pos_label )
407
-
408
- # Add an extra threshold position
409
- # to make sure that the curve starts at (0, 0)
410
- tps = torch .cat ([torch .zeros (1 , dtype = tps .dtype , device = tps .device ), tps ])
411
- fps = torch .cat ([torch .zeros (1 , dtype = fps .dtype , device = fps .device ), fps ])
412
- thresholds = torch .cat ([thresholds [0 ][None ] + 1 , thresholds ])
413
-
414
- if fps [- 1 ] <= 0 :
415
- raise ValueError ("No negative samples in targets, false positive value should be meaningless" )
416
-
417
- fpr = fps / fps [- 1 ]
418
-
419
- if tps [- 1 ] <= 0 :
420
- raise ValueError ("No positive samples in targets, true positive value should be meaningless" )
421
-
422
- tpr = tps / tps [- 1 ]
423
-
424
- return fpr , tpr , thresholds
425
-
426
-
427
- # TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
428
- def multiclass_roc (
429
- pred : torch .Tensor ,
430
- target : torch .Tensor ,
431
- sample_weight : Optional [Sequence ] = None ,
432
- num_classes : Optional [int ] = None ,
433
- ) -> Tuple [Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]:
434
- """
435
- Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
436
-
437
- .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
438
-
439
- Args:
440
- pred: estimated probabilities
441
- target: ground-truth labels
442
- sample_weight: sample weights
443
- num_classes: number of classes (default: None, computes automatically from data)
444
-
445
- Return:
446
- returns roc for each class.
447
- Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
448
-
449
- Example:
450
-
451
- >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
452
- ... [0.05, 0.85, 0.05, 0.05],
453
- ... [0.05, 0.05, 0.85, 0.05],
454
- ... [0.05, 0.05, 0.05, 0.85]])
455
- >>> target = torch.tensor([0, 1, 3, 2])
456
- >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
457
- ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
458
- (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
459
- (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
460
- (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
461
- """
462
- rank_zero_warn (
463
- "This `multiclass_roc` was deprecated in v1.1.0 in favor of"
464
- " `from pytorch_lightning.metrics.functional.roc import roc`."
465
- " It will be removed in v1.3.0" , DeprecationWarning
466
- )
467
- num_classes = get_num_classes (pred , target , num_classes )
468
-
469
- class_roc_vals = []
470
- for c in range (num_classes ):
471
- pred_c = pred [:, c ]
472
-
473
- class_roc_vals .append (_roc (pred = pred_c , target = target , sample_weight = sample_weight , pos_label = c ))
474
-
475
- return tuple (class_roc_vals )
476
-
477
-
301
+ # todo: remove in 1.4
478
302
def auc (
479
303
x : torch .Tensor ,
480
304
y : torch .Tensor ,
@@ -508,6 +332,7 @@ def auc(
508
332
return __auc (x , y )
509
333
510
334
335
+ # todo: remove in 1.4
511
336
def auc_decorator () -> Callable :
512
337
rank_zero_warn ("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0" , DeprecationWarning )
513
338
@@ -524,6 +349,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
524
349
return wrapper
525
350
526
351
352
+ # todo: remove in 1.4
527
353
def multiclass_auc_decorator () -> Callable :
528
354
rank_zero_warn (
529
355
"This `multiclass_auc_decorator` was deprecated in v1.2.0."
@@ -546,6 +372,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
546
372
return wrapper
547
373
548
374
375
+ # todo: remove in 1.4
549
376
def auroc (
550
377
pred : torch .Tensor ,
551
378
target : torch .Tensor ,
@@ -588,6 +415,7 @@ def auroc(
588
415
)
589
416
590
417
418
+ # todo: remove in 1.4
591
419
def multiclass_auroc (
592
420
pred : torch .Tensor ,
593
421
target : torch .Tensor ,
@@ -767,68 +595,3 @@ def iou(
767
595
num_classes = num_classes ,
768
596
reduction = reduction
769
597
)
770
-
771
-
772
- # todo: remove in 1.3
773
- def precision_recall_curve (
774
- pred : torch .Tensor ,
775
- target : torch .Tensor ,
776
- sample_weight : Optional [Sequence ] = None ,
777
- pos_label : int = 1. ,
778
- ):
779
- """
780
- Computes precision-recall pairs for different thresholds.
781
-
782
- .. warning :: Deprecated in favor of
783
- :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
784
- """
785
- rank_zero_warn (
786
- "This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
787
- " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
788
- " It will be removed in v1.3.0" , DeprecationWarning
789
- )
790
- return __prc (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
791
-
792
-
793
- # todo: remove in 1.3
794
- def multiclass_precision_recall_curve (
795
- pred : torch .Tensor ,
796
- target : torch .Tensor ,
797
- sample_weight : Optional [Sequence ] = None ,
798
- num_classes : Optional [int ] = None ,
799
- ):
800
- """
801
- Computes precision-recall pairs for different thresholds given a multiclass scores.
802
-
803
- .. warning :: Deprecated in favor of
804
- :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
805
- """
806
- rank_zero_warn (
807
- "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
808
- " `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
809
- " It will be removed in v1.3.0" , DeprecationWarning
810
- )
811
- if num_classes is None :
812
- num_classes = get_num_classes (pred , target , num_classes )
813
- return __prc (preds = pred , target = target , sample_weights = sample_weight , num_classes = num_classes )
814
-
815
-
816
- # todo: remove in 1.3
817
- def average_precision (
818
- pred : torch .Tensor ,
819
- target : torch .Tensor ,
820
- sample_weight : Optional [Sequence ] = None ,
821
- pos_label : int = 1. ,
822
- ):
823
- """
824
- Compute average precision from prediction scores.
825
-
826
- .. warning :: Deprecated in favor of
827
- :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
828
- """
829
- rank_zero_warn (
830
- "This `average_precision` was deprecated in v1.1.0 in favor of"
831
- " `pytorch_lightning.metrics.functional.average_precision import average_precision`."
832
- " It will be removed in v1.3.0" , DeprecationWarning
833
- )
834
- return __ap (preds = pred , target = target , sample_weights = sample_weight , pos_label = pos_label )
0 commit comments