Skip to content

Commit 32a237a

Browse files
myron and monai-bot authored
fixes DiceCELoss for multichannel targets (#5292)
Fixes DiceCELoss for multichannel targets. Currently, if "target" (the ground-truth label) is provided as multichannel data (each channel being binary or float), DiceCELoss attempts to convert it to a single channel using argmax, which can be impossible with overlapping labels. There is no need for argmax, since PyTorch's cross entropy (version 1.10 and later) already handles multichannel targets.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: myron <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Co-authored-by: monai-bot <[email protected]>
1 parent 9731823 commit 32a237a

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

monai/data/box_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -657,7 +657,7 @@ def boxes_center_distance(
657657
center2 = box_centers(boxes2_t.to(COMPUTE_DTYPE)) # (M, spatial_dims)
658658

659659
if euclidean:
660-
dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()
660+
dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt() # type: ignore
661661
else:
662662
# before sum: (N, M, spatial_dims)
663663
dists = (center1[:, None] - center2[None]).sum(-1)

monai/losses/dice.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from monai.losses.focal_loss import FocalLoss
2222
from monai.losses.spatial_mask import MaskedLoss
2323
from monai.networks import one_hot
24-
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
24+
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
2525

2626

2727
class DiceLoss(_Loss):
@@ -692,6 +692,7 @@ def __init__(
692692
raise ValueError("lambda_ce should be no less than 0.0.")
693693
self.lambda_dice = lambda_dice
694694
self.lambda_ce = lambda_ce
695+
self.old_pt_ver = not pytorch_after(1, 10)
695696

696697
def ce(self, input: torch.Tensor, target: torch.Tensor):
697698
"""
@@ -701,12 +702,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor):
701702
702703
"""
703704
n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
704-
if n_pred_ch == n_target_ch:
705-
# target is in the one-hot format, convert to BH[WD] format to calculate ce loss
706-
target = torch.argmax(target, dim=1)
707-
else:
705+
if n_pred_ch != n_target_ch and n_target_ch == 1:
708706
target = torch.squeeze(target, dim=1)
709-
target = target.long()
707+
target = target.long()
708+
elif self.old_pt_ver:
709+
warnings.warn(
710+
f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. "
711+
"Using argmax (as a workaround) to convert target to a single channel."
712+
)
713+
target = torch.argmax(target, dim=1)
714+
710715
return self.cross_entropy(input, target)
711716

712717
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)