
Commit 5d73fbb

Authored by lucadiliello, Borda, SkafteNicki, and mergify[bot]
Mean Average Precision metric for Information Retrieval (1/5) (#5032)
* init information retrieval metrics
* changed retrieval metrics names, expanded arguments and fixed typo
* added 'Retrieval' prefix to metrics and fixed conflict with already-present 'average_precision' file
* improved code formatting
* pep8 code compatibility
* features/implemented new Mean Average Precision metrics for Information Retrieval + doc
* fixed pep8 compatibility
* removed threshold parameter and fixed typo on types in RetrievalMAP and improved doc
* improved doc, put first class-specific args in RetrievalMetric and transformed RetrievalMetric into an abstract class
* implemented tests for functional and class metric; fixed typo when input tensors are empty or when all targets are False
* fixed typos in doc and changed torch.true_divide to torch.div
* fixed typos, pep8 compatibility
* fixed types in long division in ir_average_precision and example in mean_average_precision
* RetrievalMetric states are now lists and _metric method accepts predictions and targets for easier extension
* updated CHANGELOG file
* added '# noqa: F401' flag to unused imports
* added double space before '# noqa: F401' flag
* Update CHANGELOG.md
  Co-authored-by: Jirka Borovec <[email protected]>
* changed get_mini_groups into get_group_indexes
* added checks on target inputs
* minor refactoring for code cleanliness
* split tests over exception raising into a separate function and refactored test code into multiple functions
* fixed pep8 compatibility
* implemented suggestions of @SkafteNicki
* fixed imports for isort and added type annotations to functions in test_map.py
* isort on test_map and fixed typing
* isort on retrieval and on __init__.py and utils.py in metrics package
* fixed typo in pytorch_lightning/metrics/__init__.py regarding code style
* fixed yapf compatibility
* fixed yapf compatibility
* fixed typo in doc

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 06756a8 commit 5d73fbb

File tree

12 files changed: +484 −1 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))
+
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

 - Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

docs/source/extensions/metrics.rst

Lines changed: 24 additions & 0 deletions
@@ -876,6 +876,30 @@ bleu_score [func]
 .. autofunction:: pytorch_lightning.metrics.functional.bleu_score
     :noindex:

+*****************************
+Information Retrieval Metrics
+*****************************
+
+Class Metrics (IR)
+------------------
+
+Mean Average Precision
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: pytorch_lightning.metrics.retrieval.RetrievalMAP
+    :noindex:
+
+
+Functional Metrics (IR)
+-----------------------
+
+retrieval_average_precision [func]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: pytorch_lightning.metrics.functional.ir_average_precision.retrieval_average_precision
+    :noindex:
+
 ********
 Pairwise
 ********

pytorch_lightning/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,3 +37,4 @@
     R2Score,
     SSIM,
 )
+from pytorch_lightning.metrics.retrieval import RetrievalMAP  # noqa: F401

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance  # noqa: F401
 from pytorch_lightning.metrics.functional.image_gradients import image_gradients  # noqa: F401
 from pytorch_lightning.metrics.functional.iou import iou  # noqa: F401
+from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error  # noqa: F401
pytorch_lightning/metrics/functional/ir_average_precision.py (new file)

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Computes average precision (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.

    `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``,
    0 is returned. `target` must be of type `bool` or `int`, otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.

    Return:
        a single-value tensor with the average precision (AP) of the predictions `preds` w.r.t. the labels `target`.

    Example:
        >>> preds = torch.tensor([0.2, 0.3, 0.5])
        >>> target = torch.tensor([True, False, True])
        >>> retrieval_average_precision(preds, target)
        tensor(0.8333)
    """

    if preds.shape != target.shape or preds.device != target.device:
        raise ValueError("`preds` and `target` must have the same shape and live on the same device")

    if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
        raise ValueError("`target` must be a tensor of booleans or integers")

    if target.dtype is not torch.bool:
        target = target.bool()

    if target.sum() == 0:
        return torch.tensor(0, device=preds.device)

    # Rank targets by descending score, then average the precision at each relevant position
    target = target[torch.argsort(preds, dim=-1, descending=True)]
    positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
    res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
    return res
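As a sanity check on the doctest above: sorting by descending score puts the two relevant documents at ranks 1 and 3, so AP = (1/1 + 2/3) / 2 ≈ 0.8333. Below is a minimal, self-contained sketch of the same computation, assuming only `torch`; the variable names are illustrative and not part of the API.

import torch

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# Reorder targets by descending prediction score: [True, False, True]
order = torch.argsort(preds, descending=True)
sorted_target = target[order]

# Precision at each rank, kept only at the relevant ranks: [1/1, 2/3]
ranks = torch.arange(1, len(sorted_target) + 1, dtype=torch.float32)
hits = torch.cumsum(sorted_target.to(torch.float32), dim=0)
precisions = (hits / ranks)[sorted_target]

print(precisions.mean())  # tensor(0.8333)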
pytorch_lightning/metrics/retrieval/__init__.py (new file)

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP  # noqa: F401
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
pytorch_lightning/metrics/retrieval/mean_average_precision.py (new file)

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
import torch

from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalMAP(RetrievalMetric):
    r"""
    Computes `Mean Average Precision
    <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.

    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts
    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension.
    `indexes` indicate to which query a prediction belongs.
    Predictions will first be grouped by indexes and then MAP will be computed as the mean
    of the Average Precisions over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``
        exclude:
            Do not take into account predictions where the target is equal to this value. default: `-100`
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    Example:
        >>> from pytorch_lightning.metrics import RetrievalMAP
        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = torch.tensor([False, False, True, False, True, False, False])

        >>> map = RetrievalMAP()
        >>> map(indexes, preds, target)
        tensor(0.7500)
        >>> map.compute()
        tensor(0.7500)
    """

    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Drop predictions whose target equals the `exclude` sentinel before scoring
        valid_indexes = target != self.exclude
        return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
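To unpack the ``tensor(0.7500)`` in the example: query ``0`` covers the first three positions, where the only relevant document has the highest score, so its AP is 1.0; query ``1`` covers the remaining four, where the relevant document ranks second, so its AP is 0.5; the mean is 0.75. The same numbers can be reproduced with the functional version shipped in this PR (a sketch reusing the example's inputs):

import torch
from pytorch_lightning.metrics.functional import retrieval_average_precision

# Query 0: the relevant document has the top score -> AP = 1.0
ap0 = retrieval_average_precision(
    torch.tensor([0.2, 0.3, 0.5]), torch.tensor([False, False, True])
)
# Query 1: the relevant document ranks second of four -> AP = 0.5
ap1 = retrieval_average_precision(
    torch.tensor([0.1, 0.3, 0.5, 0.2]), torch.tensor([False, True, False, False])
)
print((ap0 + ap1) / 2)  # tensor(0.7500)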
pytorch_lightning/metrics/retrieval/retrieval_metric.py (new file)

Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

import torch

from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.utils import get_group_indexes

#: get_group_indexes is used to group predictions belonging to the same query

IGNORE_IDX = -100


class RetrievalMetric(Metric, ABC):
    r"""
    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts
    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float or int tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension and will be flattened
    to a single dimension once provided.

    `indexes` indicate to which query a prediction belongs.
    Predictions will first be grouped by indexes. Then the
    actual metric, defined by overriding the `_metric` method,
    will be computed as the mean of the scores over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``
        exclude:
            Do not take into account predictions where the target is equal to this value. default: `-100`
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    """

    def __init__(
        self,
        query_without_relevant_docs: str = 'skip',
        exclude: int = IGNORE_IDX,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn
        )

        query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
        if query_without_relevant_docs not in query_without_relevant_docs_options:
            raise ValueError(
                f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. "
                f"Allowed values are {query_without_relevant_docs_options}"
            )

        self.query_without_relevant_docs = query_without_relevant_docs
        self.exclude = exclude

        self.add_state("idx", default=[], dist_reduce_fx=None)
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
        if not (idx.shape == target.shape == preds.shape):
            raise ValueError("`idx`, `preds` and `target` must be of the same shape")

        idx = idx.to(dtype=torch.int64).flatten()
        preds = preds.to(dtype=torch.float32).flatten()
        target = target.to(dtype=torch.int64).flatten()

        self.idx.append(idx)
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> torch.Tensor:
        r"""
        First concatenate the states `idx`, `preds` and `target` since they were stored as lists. After that,
        compute the list of groups that keep together predictions about the same query.
        Finally, for each group compute the `_metric` if the number of positive targets is at least
        1, otherwise behave as specified by `self.query_without_relevant_docs`.
        """

        idx = torch.cat(self.idx, dim=0)
        preds = torch.cat(self.preds, dim=0)
        target = torch.cat(self.target, dim=0)

        res = []
        kwargs = {'device': idx.device, 'dtype': torch.float32}

        groups = get_group_indexes(idx)
        for group in groups:

            mini_preds = preds[group]
            mini_target = target[group]

            if not mini_target.sum():
                if self.query_without_relevant_docs == 'error':
                    raise ValueError(
                        f"`{self.__class__.__name__}.compute()` was provided with "
                        f"a query without positive targets, indexes: {group}"
                    )
                if self.query_without_relevant_docs == 'pos':
                    res.append(torch.tensor(1.0, **kwargs))
                elif self.query_without_relevant_docs == 'neg':
                    res.append(torch.tensor(0.0, **kwargs))
            else:
                res.append(self._metric(mini_preds, mini_target))

        if len(res) > 0:
            return torch.stack(res).mean()
        return torch.tensor(0.0, **kwargs)

    @abstractmethod
    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Compute a metric over the predictions and target of a single group.
        This method should be overridden by subclasses.
        """
pytorch_lightning/metrics/utils.py

Lines changed: 30 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

@@ -93,6 +93,35 @@ def _input_format_classification_one_hot(
     return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)


+def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]:
+    """
+    Given an integer `torch.Tensor` `idx`, return a list of index `torch.Tensor`s,
+    one for each distinct value in `idx`.
+
+    Args:
+        idx: a `torch.Tensor` of integers
+
+    Return:
+        A list of integer `torch.Tensor`s
+
+    Example:
+
+        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
+        >>> groups = get_group_indexes(indexes)
+        >>> groups
+        [tensor([0, 1, 2]), tensor([3, 4, 5, 6])]
+    """
+
+    indexes = dict()
+    for i, _id in enumerate(idx):
+        _id = _id.item()
+        if _id in indexes:
+            indexes[_id] += [i]
+        else:
+            indexes[_id] = [i]
+    return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()]
+
+
 def to_onehot(
     label_tensor: torch.Tensor,
     num_classes: Optional[int] = None,