Skip to content

Commit bb3ecf6

Browse files
authored
3694 loss function as cumulative metric (#5513)
Signed-off-by: Wenqi Li <[email protected]> Fixes #3694 ### Description adds a wrapper for loss function: ```py dice_loss = DiceLoss(include_background=True) loss_metric = LossMetric(loss_fn=dice_loss) ``` so that `loss_metric` it can be used as a metric in: https://github.com/Project-MONAI/MONAI/blob/b030839e98e6a00c1ab0e53545f60b62dc4da4a4/monai/handlers/ignite_metric.py#L32 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]>
1 parent 5a6672c commit bb3ecf6

File tree

5 files changed

+170
-2
lines changed

5 files changed

+170
-2
lines changed

docs/source/metrics.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ Metrics
4646
.. autoclass:: CumulativeIterationMetric
4747
:members:
4848

49+
`LossMetric`
50+
------------
51+
.. autoclass:: LossMetric
52+
:members:
53+
4954
`Mean Dice`
5055
-----------
5156
.. autofunction:: compute_meandice

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
1717
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
1818
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
19+
from .loss_metric import LossMetric
1920
from .meandice import DiceMetric, compute_dice, compute_meandice
2021
from .meaniou import MeanIoU, compute_iou, compute_meaniou
2122
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric

monai/metrics/loss_metric.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Union
13+
14+
import torch
15+
from torch.nn.modules.loss import _Loss
16+
17+
from monai.metrics.utils import do_metric_reduction
18+
from monai.utils import MetricReduction
19+
20+
from .metric import CumulativeIterationMetric
21+
22+
23+
class LossMetric(CumulativeIterationMetric):
24+
"""
25+
A wrapper to make ``loss_fn`` available as a cumulative metric. That is, the loss values computed from
26+
mini-batches can be combined in the ``reduction`` mode across multiple iterations, as a quantitative measurement
27+
of a model.
28+
29+
Example:
30+
31+
.. code-block:: python
32+
33+
import torch
34+
from monai.losses import DiceLoss
35+
from monai.metrics import LossMetric
36+
37+
dice_loss = DiceLoss(include_background=True)
38+
loss_metric = LossMetric(loss_fn=dice_loss)
39+
40+
# first iteration
41+
y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) # shape [batch=1, channel=1, 2, 2]
42+
y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) # shape [batch=1, channel=1, 2, 2]
43+
loss_metric(y_pred, y)
44+
45+
# second iteration
46+
y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]]) # shape [batch=1, channel=1, 2, 2]
47+
y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]) # shape [batch=1, channel=1, 2, 2]
48+
loss_metric(y_pred, y)
49+
50+
# aggregate
51+
print(loss_metric.aggregate(reduction="none")) # tensor([[0.2000], [0.5000]]) (shape [batch=2, channel=1])
52+
53+
# reset
54+
loss_metric.reset()
55+
print(loss_metric.aggregate())
56+
57+
58+
Args:
59+
loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the "batch-first" format),
60+
returns a "batch-first" tensor of loss values.
61+
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
62+
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
63+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
64+
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
65+
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
66+
67+
"""
68+
69+
def __init__(
70+
self, loss_fn: _Loss, reduction: Union[MetricReduction, str] = MetricReduction.MEAN, get_not_nans: bool = False
71+
) -> None:
72+
super().__init__()
73+
self.loss_fn = loss_fn
74+
self.reduction = reduction
75+
self.get_not_nans = get_not_nans
76+
77+
def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
78+
"""
79+
Returns the aggregated loss value across multiple iterations.
80+
81+
Args:
82+
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
83+
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
84+
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
85+
"""
86+
data = self.get_buffer()
87+
if data is None:
88+
return (torch.tensor(0.0), torch.tensor(0.0)) if self.get_not_nans else torch.tensor(0.0)
89+
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
90+
return (f, not_nans) if self.get_not_nans else f
91+
92+
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor = None): # type: ignore
93+
"""
94+
Input `y_pred` is compared with ground truth `y`.
95+
Both `y_pred` and `y` are expected to be a batch-first Tensor (BC[HWD]).
96+
97+
Returns:
98+
a tensor with shape (BC[HWD]), or a list of tensors, each tensor with shape (C[HWD]).
99+
"""
100+
iter_loss = self.loss_fn(y_pred) if y is None else self.loss_fn(y_pred, y)
101+
if isinstance(iter_loss, torch.Tensor):
102+
while iter_loss.dim() < 2:
103+
iter_loss = iter_loss[None]
104+
# to be compatible with `Cumulative`, iter_loss should at least have a batch dim.
105+
# to be compatible with `do_metric_reduction`, iter_loss should at least have a batch and a channel dim.
106+
return iter_loss

monai/metrics/metric.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def extend(self, *data) -> None:
202202
for b, d in zip(self._buffers, data):
203203
# converting to pytorch tensors so that we can use the distributed API
204204
d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True)
205-
try:
205+
try: # d_t must be a mini-batch of values
206206
b.extend([x[0] for x in torch.split(d_t, 1, dim=0)])
207207
except (AttributeError, IndexError, RuntimeError) as e:
208208
raise TypeError(
@@ -286,6 +286,7 @@ class CumulativeIterationMetric(Cumulative, IterationMetric):
286286
287287
Typically, it computes some intermediate results for each iteration, adds them to the buffers,
288288
then the buffer contents could be gathered and aggregated for the final result when epoch completed.
289+
Currently,``Cumulative.aggregate()`` and ``IterationMetric._compute_tensor()`` are expected to be implemented.
289290
290291
For example, `MeanDice` inherits this class and the usage is as follows:
291292
@@ -324,7 +325,8 @@ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None):
324325
or a `batch-first` Tensor.
325326
326327
Returns:
327-
The computed metric values at the iteration level.
328+
The computed metric values at the iteration level. The output shape should be
329+
a `batch-first` tensor (BC[HWD]) or a list of `batch-first` tensors.
328330
"""
329331
ret = super().__call__(y_pred=y_pred, y=y)
330332
if isinstance(ret, (tuple, list)):

tests/test_loss_metric.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
from parameterized import parameterized
17+
18+
from monai.losses import DiceLoss
19+
from monai.metrics import LossMetric
20+
21+
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
22+
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)
23+
{
24+
"loss_class": DiceLoss,
25+
"loss_kwargs": {"include_background": True},
26+
"reduction": "mean",
27+
"get_not_nans": False,
28+
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
29+
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
30+
"include_background": True,
31+
},
32+
[0.2],
33+
]
34+
35+
36+
class TestComputeLossMetric(unittest.TestCase):
37+
@parameterized.expand([TEST_CASE_1])
38+
def test_value_class(self, input_data, expected_value):
39+
loss_fn = input_data["loss_class"](**input_data["loss_kwargs"])
40+
loss_metric = LossMetric(
41+
loss_fn=loss_fn, reduction=input_data["reduction"], get_not_nans=input_data["get_not_nans"]
42+
)
43+
44+
loss_metric(y_pred=input_data.get("y_pred"), y=input_data.get("y"))
45+
loss_metric(y_pred=input_data.get("y_pred"), y=input_data.get("y"))
46+
result = loss_metric.aggregate()
47+
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
48+
loss_metric.reset()
49+
result = loss_metric.aggregate()
50+
np.testing.assert_allclose(result.cpu().numpy(), 0.0, atol=1e-4)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

0 commit comments

Comments
 (0)