Skip to content

Commit 06756a8

Browse files
dipam7Bordarohitgr7akihironitta
authored
document exceptions for metrics/functional (#6273)
* document exceptions for metrics/functional * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 156847b commit 06756a8

File tree

11 files changed

+130
-0
lines changed

11 files changed

+130
-0
lines changed

pytorch_lightning/metrics/functional/accuracy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def accuracy(
102102
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
103103
still applies in both cases, if set.
104104
105+
Raises:
106+
ValueError:
107+
If ``top_k`` parameter is set for ``multi-label`` inputs.
108+
105109
Example:
106110
107111
>>> from pytorch_lightning.metrics.functional import accuracy

pytorch_lightning/metrics/functional/auc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor
6363
Return:
6464
Tensor containing AUC score (float)
6565
66+
Raises:
67+
ValueError:
68+
If both ``x`` and ``y`` tensors are not ``1d``.
69+
ValueError:
70+
If both ``x`` and ``y`` don't have the same numnber of elements.
71+
ValueError:
72+
If ``x`` tesnsor is neither increasing or decreasing.
73+
6674
Example:
6775
>>> from pytorch_lightning.metrics.functional import auc
6876
>>> x = torch.tensor([0, 1, 2, 3])

pytorch_lightning/metrics/functional/auroc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,18 @@ def auroc(
165165
range [0, max_fpr]. Should be a float between 0 and 1.
166166
sample_weight: sample weights for each data point
167167
168+
Raises:
169+
ValueError:
170+
If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
171+
RuntimeError:
172+
If ``PyTorch version`` is ``below 1.6`` since max_fpr requires `torch.bucketize`
173+
which is not available below 1.6.
174+
ValueError:
175+
If ``max_fpr`` is not set to ``None`` and the mode is ``not binary``
176+
since partial AUC computation is not available in multilabel/multiclass.
177+
ValueError:
178+
If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
179+
168180
Example (binary case):
169181
170182
>>> from pytorch_lightning.metrics.functional import auroc

pytorch_lightning/metrics/functional/classification.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def stat_scores_multiple_classes(
7777
7878
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores`
7979
80+
Raises:
81+
ValueError:
82+
If ``reduction`` is not one of ``"none"``, ``"sum"`` or ``"elementwise_mean"``.
8083
"""
8184

8285
rank_zero_warn(
@@ -439,6 +442,16 @@ def multiclass_auroc(
439442
Return:
440443
Tensor containing ROCAUC score
441444
445+
Raises:
446+
ValueError:
447+
If ``pred`` don't sum up to ``1`` over classes for ``Multiclass AUROC``.
448+
ValueError:
449+
If number of classes found in ``target`` does not equal the number of
450+
columns in ``pred``.
451+
ValueError:
452+
If number of classes deduced from ``pred`` does not equal the number of
453+
classes passed in ``num_classes``.
454+
442455
Example:
443456
444457
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],

pytorch_lightning/metrics/functional/image_gradients.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5454
Return:
5555
Tuple of (dy, dx) with each gradient of shape ``[N, C, H, W]``
5656
57+
Raises:
58+
TypeError:
59+
If ``img`` is not of the type <torch.Tensor>.
60+
RuntimeError:
61+
If ``img`` is not a 4D tensor.
62+
5763
Example:
5864
>>> from pytorch_lightning.metrics.functional import image_gradients
5965
>>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)

pytorch_lightning/metrics/functional/precision_recall.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ def precision(
133133
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
134134
of classes
135135
136+
Raises:
137+
ValueError:
138+
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
139+
``"samples"``, ``"none"`` or ``None``.
140+
ValueError:
141+
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
142+
ValueError:
143+
If ``average`` is set but ``num_classes`` is not provided.
144+
ValueError:
145+
If ``num_classes`` is set
146+
and ``ignore_index`` is not in the range ``[0, num_classes)``.
147+
136148
Example:
137149
138150
>>> from pytorch_lightning.metrics.functional import precision
@@ -295,6 +307,18 @@ def recall(
295307
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
296308
of classes
297309
310+
Raises:
311+
ValueError:
312+
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
313+
``"samples"``, ``"none"`` or ``None``.
314+
ValueError:
315+
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
316+
ValueError:
317+
If ``average`` is set but ``num_classes`` is not provided.
318+
ValueError:
319+
If ``num_classes`` is set
320+
and ``ignore_index`` is not in the range ``[0, num_classes)``.
321+
298322
Example:
299323
300324
>>> from pytorch_lightning.metrics.functional import recall
@@ -444,6 +468,18 @@ def precision_recall(
444468
- If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for
445469
the number of classes
446470
471+
Raises:
472+
ValueError:
473+
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
474+
``"samples"``, ``"none"`` or ``None``.
475+
ValueError:
476+
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
477+
ValueError:
478+
If ``average`` is set but ``num_classes`` is not provided.
479+
ValueError:
480+
If ``num_classes`` is set
481+
and ``ignore_index`` is not in the range ``[0, num_classes)``.
482+
447483
Example:
448484
449485
>>> from pytorch_lightning.metrics.functional import precision_recall

pytorch_lightning/metrics/functional/precision_recall_curve.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,14 @@ def precision_recall_curve(
197197
>>> thresholds
198198
tensor([1, 2, 3])
199199
200+
Raises:
201+
ValueError:
202+
If ``preds`` and ``target`` don't have the same number of dimensions,
203+
or one additional dimension for ``preds``.
204+
ValueError:
205+
If the number of classes deduced from ``preds`` is not the same as the
206+
``num_classes`` provided.
207+
200208
Example (multiclass case):
201209
202210
>>> from pytorch_lightning.metrics.functional import precision_recall_curve

pytorch_lightning/metrics/functional/psnr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def psnr(
8484
Return:
8585
Tensor with PSNR score
8686
87+
Raises:
88+
ValueError:
89+
If ``dim`` is not ``None`` and ``data_range`` is not provided.
90+
8791
Example:
8892
>>> from pytorch_lightning.metrics.functional import psnr
8993
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])

pytorch_lightning/metrics/functional/r2score.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def r2score(
114114
* ``'uniform_average'`` scores are uniformly averaged
115115
* ``'variance_weighted'`` scores are weighted by their individual variances
116116
117+
Raises:
118+
ValueError:
119+
If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors.
120+
ValueError:
121+
If ``len(preds)`` is less than ``2``
122+
since at least ``2`` sampels are needed to calculate r2 score.
123+
ValueError:
124+
If ``multioutput`` is not one of ``raw_values``,
125+
``uniform_average`` or ``variance_weighted``.
126+
ValueError:
127+
If ``adjusted`` is not an ``integer`` greater than ``0``.
128+
117129
Example:
118130
119131
>>> from pytorch_lightning.metrics.functional import r2score

pytorch_lightning/metrics/functional/ssim.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ def ssim(
143143
Return:
144144
Tensor with SSIM score
145145
146+
Raises:
147+
TypeError:
148+
If ``preds`` and ``target`` don't have the same data type.
149+
ValueError:
150+
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
151+
ValueError:
152+
If the length of ``kernel_size`` or ``sigma`` is not ``2``.
153+
ValueError:
154+
If one of the elements of ``kernel_size`` is not an ``odd positive number``.
155+
ValueError:
156+
If one of the elements of ``sigma`` is not a ``positive number``.
157+
146158
Example:
147159
>>> from pytorch_lightning.metrics.functional import ssim
148160
>>> preds = torch.rand([16, 1, 16, 16])

pytorch_lightning/metrics/functional/stat_scores.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,21 @@ def stat_scores(
244244
- If ``reduce='macro'``, the shape will be ``(N, C, 5)``
245245
- If ``reduce='samples'``, the shape will be ``(N, X, 5)``
246246
247+
Raises:
248+
ValueError:
249+
If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``.
250+
ValueError:
251+
If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``.
252+
ValueError:
253+
If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided.
254+
ValueError:
255+
If ``num_classes`` is set
256+
and ``ignore_index`` is not in the range ``[0, num_classes)``.
257+
ValueError:
258+
If ``ignore_index`` is used with ``binary data``.
259+
ValueError:
260+
If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided.
261+
247262
Example:
248263
249264
>>> from pytorch_lightning.metrics.functional import stat_scores

0 commit comments

Comments
 (0)