13
13
# limitations under the License.
14
14
from typing import Any , Optional
15
15
16
- import torch
17
- from torchmetrics import Metric
16
+ from torchmetrics import F1 as _F1
17
+ from torchmetrics import FBeta as _FBeta
18
18
19
- from pytorch_lightning .metrics .functional .f_beta import _fbeta_compute , _fbeta_update
20
- from pytorch_lightning .utilities import rank_zero_warn
19
+ from pytorch_lightning .utilities .deprecation import deprecated
21
20
22
21
23
- class FBeta (Metric ):
24
- r"""
25
- Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_, specifically:
26
-
27
- .. math::
28
- F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}
29
- {(\beta^2 * \text{precision}) + \text{recall}}
30
-
31
- Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data.
32
- Accepts probabilities from a model output or integer class values in prediction.
33
- Works with multi-dimensional preds and target.
34
-
35
- Forward accepts
36
-
37
- - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
38
- - ``target`` (long tensor): ``(N, ...)``
39
-
40
- If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
41
- to convert into integer labels. This is the case for binary and multi-label probabilities.
42
-
43
- If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
44
-
45
- Args:
46
- num_classes: Number of classes in the dataset.
47
- beta: Beta coefficient in the F measure.
48
- threshold:
49
- Threshold value for binary or multi-label probabilities. default: 0.5
50
-
51
- average:
52
- - ``'micro'`` computes metric globally
53
- - ``'macro'`` computes metric for each class and uniformly averages them
54
- - ``'weighted'`` computes metric for each class and does a weighted-average,
55
- where each class is weighted by their support (accounts for class imbalance)
56
- - ``'none'`` or ``None`` computes and returns the metric per class
57
-
58
- multilabel: If predictions are from multilabel classification.
59
- compute_on_step:
60
- Forward only calls ``update()`` and return None if this is set to False. default: True
61
- dist_sync_on_step:
62
- Synchronize metric state across processes at each ``forward()``
63
- before returning the value at the step. default: False
64
- process_group:
65
- Specify the process group on which synchronization is called. default: None (which selects the entire world)
66
-
67
- Raises:
68
- ValueError:
69
- If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``.
70
-
71
- Example:
72
-
73
- >>> from pytorch_lightning.metrics import FBeta
74
- >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
75
- >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
76
- >>> f_beta = FBeta(num_classes=3, beta=0.5)
77
- >>> f_beta(preds, target)
78
- tensor(0.3333)
79
-
80
- """
22
+ class FBeta (_FBeta ):
81
23
24
+ @deprecated (target = _FBeta , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
82
25
def __init__ (
83
26
self ,
84
27
num_classes : int ,
@@ -90,103 +33,17 @@ def __init__(
90
33
dist_sync_on_step : bool = False ,
91
34
process_group : Optional [Any ] = None ,
92
35
):
93
- super ().__init__ (
94
- compute_on_step = compute_on_step ,
95
- dist_sync_on_step = dist_sync_on_step ,
96
- process_group = process_group ,
97
- )
98
-
99
- self .num_classes = num_classes
100
- self .beta = beta
101
- self .threshold = threshold
102
- self .average = average
103
- self .multilabel = multilabel
104
-
105
- allowed_average = ("micro" , "macro" , "weighted" , "none" , None )
106
- if self .average not in allowed_average :
107
- raise ValueError (
108
- 'Argument `average` expected to be one of the following:'
109
- f' { allowed_average } but got { self .average } '
110
- )
111
-
112
- self .add_state ("true_positives" , default = torch .zeros (num_classes ), dist_reduce_fx = "sum" )
113
- self .add_state ("predicted_positives" , default = torch .zeros (num_classes ), dist_reduce_fx = "sum" )
114
- self .add_state ("actual_positives" , default = torch .zeros (num_classes ), dist_reduce_fx = "sum" )
115
-
116
- def update (self , preds : torch .Tensor , target : torch .Tensor ):
117
- """
118
- Update state with predictions and targets.
119
-
120
- Args:
121
- preds: Predictions from model
122
- target: Ground truth values
123
36
"""
124
- true_positives , predicted_positives , actual_positives = _fbeta_update (
125
- preds , target , self .num_classes , self .threshold , self .multilabel
126
- )
127
-
128
- self .true_positives += true_positives
129
- self .predicted_positives += predicted_positives
130
- self .actual_positives += actual_positives
37
+ This implementation refers to :class:`~torchmetrics.FBeta`.
131
38
132
- def compute (self ) -> torch .Tensor :
39
+ .. deprecated::
40
+ Use :class:`~torchmetrics.FBeta`. Will be removed in v1.5.0.
133
41
"""
134
- Computes fbeta over state.
135
- """
136
- return _fbeta_compute (
137
- self .true_positives , self .predicted_positives , self .actual_positives , self .beta , self .average
138
- )
139
-
140
-
141
- class F1 (FBeta ):
142
- """
143
- Computes F1 metric. F1 metrics correspond to a harmonic mean of the
144
- precision and recall scores.
145
-
146
- Works with binary, multiclass, and multilabel data.
147
- Accepts logits from a model output or integer class values in prediction.
148
- Works with multi-dimensional preds and target.
149
42
150
- Forward accepts
151
43
152
- - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
153
- - ``target`` (long tensor): ``(N, ...)``
154
-
155
- If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
156
- This is the case for binary and multi-label logits.
157
-
158
- If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
159
-
160
- Args:
161
- num_classes: Number of classes in the dataset.
162
- threshold:
163
- Threshold value for binary or multi-label logits. default: 0.5
164
-
165
- average:
166
- - ``'micro'`` computes metric globally
167
- - ``'macro'`` computes metric for each class and uniformly averages them
168
- - ``'weighted'`` computes metric for each class and does a weighted-average,
169
- where each class is weighted by their support (accounts for class imbalance)
170
- - ``'none'`` or ``None`` computes and returns the metric per class
171
-
172
- multilabel: If predictions are from multilabel classification.
173
- compute_on_step:
174
- Forward only calls ``update()`` and returns None if this is set to False. default: True
175
- dist_sync_on_step:
176
- Synchronize metric state across processes at each ``forward()``
177
- before returning the value at the step. default: False
178
- process_group:
179
- Specify the process group on which synchronization is called. default: None (which selects the entire world)
180
-
181
- Example:
182
- >>> from pytorch_lightning.metrics import F1
183
- >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
184
- >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
185
- >>> f1 = F1(num_classes=3)
186
- >>> f1(preds, target)
187
- tensor(0.3333)
188
- """
44
+ class F1 (_F1 ):
189
45
46
+ @deprecated (target = _F1 , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
190
47
def __init__ (
191
48
self ,
192
49
num_classes : int ,
@@ -197,16 +54,9 @@ def __init__(
197
54
dist_sync_on_step : bool = False ,
198
55
process_group : Optional [Any ] = None ,
199
56
):
200
- if multilabel is not False :
201
- rank_zero_warn ( f'The `multilabel= { multilabel } ` parameter is unused and will not have any effect.' )
57
+ """
58
+ This implementation refers to :class:`~torchmetrics.F1`.
202
59
203
- super ().__init__ (
204
- num_classes = num_classes ,
205
- beta = 1.0 ,
206
- threshold = threshold ,
207
- average = average ,
208
- multilabel = multilabel ,
209
- compute_on_step = compute_on_step ,
210
- dist_sync_on_step = dist_sync_on_step ,
211
- process_group = process_group ,
212
- )
60
+ .. deprecated::
61
+ Use :class:`~torchmetrics.F1`. Will be removed in v1.5.0.
62
+ """
0 commit comments