Commit c0aace5

Add transform to handle empty box as training data (#6170)
Fixes #5990.

### Description

Add transforms to convert empty boxes with shape (0,M) or (0,) into (0,4) or (0,6), and add format checking inside the detector so users know how to fix format issues with empty box input.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Can Zhao <[email protected]>
1 parent c885460 commit c0aace5
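
In short, the change lets zero-box annotations flow through the detection pipeline. A minimal sketch of the new behavior (assuming a MONAI build that includes this commit):

```python
import torch

from monai.data.box_utils import standardize_empty_box

# An image with no objects often arrives with an empty 1-D "boxes" tensor.
boxes = torch.zeros(0)

# The new helper reshapes it to the standard (0, 2 * spatial_dims) layout,
# here (0, 6) for 3D, so downstream box transforms and the detector accept it.
boxes_3d = standardize_empty_box(boxes, spatial_dims=3)
print(boxes_3d.shape)  # torch.Size([0, 6])
```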

File tree

6 files changed: +170 -25 lines changed

monai/apps/detection/networks/retinanet_detector.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -485,7 +485,9 @@ def forward(
         """
         # 1. Check if input arguments are valid
         if self.training:
-            check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
+            targets = check_training_targets(
+                input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
+            )
             self._check_detector_training_components()

         # 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -877,7 +879,7 @@ def get_cls_train_sample_per_image(

         foreground_idxs_per_image = matched_idxs_per_image >= 0

-        num_foreground = foreground_idxs_per_image.sum()
+        num_foreground = int(foreground_idxs_per_image.sum())
         num_gt_box = targets_per_image[self.target_box_key].shape[0]

         if self.debug:
```
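
With `forward` now consuming the return value of `check_training_targets`, a training target with zero ground-truth boxes is standardized rather than rejected. A sketch of such a target (the detector construction is elided here; see tests/test_retinanet_detector.py below for a full setup):

```python
import torch

# Illustrative empty target for a 3D detector: boxes may arrive as (0,) or
# (0, 6); check_training_targets reshapes them to (0, 6) and warns instead
# of raising, and float labels are cast to torch.long.
empty_target = {
    "boxes": torch.zeros(0, dtype=torch.float32),
    "labels": torch.zeros(0, dtype=torch.long),
}
# losses = detector(torch.randn(1, 1, 64, 64, 64), [empty_target])
# where `detector` stands for a configured RetinaNetDetector.
```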

monai/apps/detection/transforms/array.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@
     convert_box_to_standard_mode,
     get_spatial_dims,
     spatial_crop_boxes,
+    standardize_empty_box,
 )
 from monai.transforms import Rotate90, SpatialCrop
 from monai.transforms.transform import Transform
@@ -46,6 +47,7 @@
 )

 __all__ = [
+    "StandardizeEmptyBox",
     "ConvertBoxToStandardMode",
     "ConvertBoxMode",
     "AffineBox",
@@ -60,6 +62,27 @@
 ]


+class StandardizeEmptyBox(Transform):
+    """
+    When boxes are empty, this transform standardizes them to shape (0,4) or (0,6).
+
+    Args:
+        spatial_dims: number of spatial dimensions of the bounding boxes.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self, spatial_dims: int) -> None:
+        self.spatial_dims = spatial_dims
+
+    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:
+        """
+        Args:
+            boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.
+        """
+        return standardize_empty_box(boxes, spatial_dims=self.spatial_dims)
+
+
 class ConvertBoxMode(Transform):
     """
     This transform converts the boxes in src_mode to the dst_mode.
```
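
A quick usage sketch of the new array transform (the reshaping is done by `standardize_empty_box` in monai/data/box_utils.py, diffed below; non-empty boxes pass through unchanged):

```python
import torch

from monai.apps.detection.transforms.array import StandardizeEmptyBox

transform = StandardizeEmptyBox(spatial_dims=3)

print(transform(torch.ones(0)).shape)     # empty (0,) becomes torch.Size([0, 6])
print(transform(torch.ones(2, 6)).shape)  # non-empty boxes are untouched: torch.Size([2, 6])
```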

monai/apps/detection/transforms/dictionary.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -34,6 +34,7 @@
     MaskToBox,
     RotateBox90,
     SpatialCropBox,
+    StandardizeEmptyBox,
     ZoomBox,
 )
 from monai.apps.detection.transforms.box_ops import convert_box_to_mask
@@ -51,6 +52,9 @@
 from monai.utils.type_conversion import convert_data_type, convert_to_tensor

 __all__ = [
+    "StandardizeEmptyBoxd",
+    "StandardizeEmptyBoxD",
+    "StandardizeEmptyBoxDict",
     "ConvertBoxModed",
     "ConvertBoxModeD",
     "ConvertBoxModeDict",
@@ -95,6 +99,50 @@
 DEFAULT_POST_FIX = PostFix.meta()


+class StandardizeEmptyBoxd(MapTransform, InvertibleTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.
+
+    When boxes are empty, this transform standardizes them to shape (0,4) or (0,6).
+
+    Example:
+        .. code-block:: python
+
+            data = {"boxes": torch.ones(0,), "image": torch.ones(1, 128, 128, 128)}
+            box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
+            box_converter(data)
+    """
+
+    def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:
+        """
+        Args:
+            box_keys: Keys to pick data for transformation.
+            box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.
+            allow_missing_keys: don't raise exception if key is missing.
+
+        See also :py:class:`monai.apps.detection.transforms.array.ConvertBoxToStandardMode`
+        """
+        super().__init__(box_keys, allow_missing_keys)
+        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)
+        if len(box_ref_image_keys_tuple) > 1:
+            raise ValueError(
+                "Please provide a single key for box_ref_image_keys. "
+                "All boxes of box_keys are attached to box_ref_image_keys."
+            )
+        self.box_ref_image_keys = box_ref_image_keys
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        spatial_dims = len(d[self.box_ref_image_keys].shape) - 1
+        self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)
+        for key in self.key_iterator(d):
+            d[key] = self.converter(d[key])
+        return d
+
+    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        return dict(data)
+
+
 class ConvertBoxModed(MapTransform, InvertibleTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.
@@ -1353,3 +1401,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
 RandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld
 RotateBox90D = RotateBox90Dict = RotateBox90d
 RandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d
+StandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd
```
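
The docstring example expanded into a runnable sketch; note that `spatial_dims` is inferred from the channel-first reference image:

```python
import torch

from monai.apps.detection.transforms.dictionary import StandardizeEmptyBoxd

# "image" is channel-first with 3 spatial dims, so boxes standardize to (0, 6).
data = {"boxes": torch.ones(0), "image": torch.ones(1, 128, 128, 128)}
box_converter = StandardizeEmptyBoxd(box_keys=["boxes"], box_ref_image_keys="image")
print(box_converter(data)["boxes"].shape)  # torch.Size([0, 6])
```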

monai/apps/detection/utils/detector_utils.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -11,13 +11,15 @@

 from __future__ import annotations

+import warnings
 from collections.abc import Sequence
 from typing import Any

 import torch
 import torch.nn.functional as F
 from torch import Tensor

+from monai.data.box_utils import standardize_empty_box
 from monai.transforms.croppad.array import SpatialPad
 from monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode
 from monai.utils import PytorchPadMode, ensure_tuple_rep
@@ -56,7 +58,7 @@ def check_training_targets(
     spatial_dims: int,
     target_label_key: str,
     target_box_key: str,
-) -> None:
+) -> list[dict[str, Tensor]]:
     """
     Validate the input images/targets during training (raise a `ValueError` if invalid).

@@ -75,7 +77,8 @@
     if len(input_images) != len(targets):
         raise ValueError(f"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.")

-    for target in targets:
+    for i in range(len(targets)):
+        target = targets[i]
         if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
             raise ValueError(
                 f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
@@ -85,10 +88,24 @@
         if not isinstance(boxes, torch.Tensor):
             raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
         if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:
-            raise ValueError(
-                f"Expected target boxes to be a tensor " f"of shape [N, {2* spatial_dims}], got {boxes.shape}."
-            )
-    return
+            if boxes.numel() == 0:
+                warnings.warn(
+                    f"Warning: Given target boxes have shape {boxes.shape}. "
+                    f"The detector reshaped them with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
+                )
+            else:
+                raise ValueError(
+                    f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}."
+                )
+        if not torch.is_floating_point(boxes):
+            raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
+        targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims)  # type: ignore
+
+        labels = target[target_label_key]
+        if torch.is_floating_point(labels):
+            warnings.warn(f"Warning: Given target labels are {labels.dtype}. The detector converted them to torch.long.")
+            targets[i][target_label_key] = labels.long()
+    return targets


 def pad_images(
```
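
A sketch of the new `check_training_targets` contract with an empty-box target (image and target shapes are illustrative only):

```python
import torch

from monai.apps.detection.utils.detector_utils import check_training_targets

images = [torch.randn(1, 32, 32, 32)]  # one 3D image, channel-first
targets = [{"boxes": torch.zeros(0), "labels": torch.zeros(0, dtype=torch.long)}]

# Previously this raised a ValueError for the (0,) boxes; now it warns,
# reshapes boxes to (0, 6), and returns the fixed targets for forward() to use.
targets = check_training_targets(images, targets, 3, "labels", "boxes")
print(targets[0]["boxes"].shape)  # torch.Size([0, 6])
```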

monai/data/box_utils.py

Lines changed: 56 additions & 3 deletions
```diff
@@ -395,19 +395,41 @@ def get_spatial_dims(

     # Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set
     if boxes is not None:
-        if int(boxes.shape[1]) not in [4, 6]:
+        if len(boxes.shape) != 2:
+            if boxes.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], "
+                    f"got boxes with shape {boxes.shape}. "
+                    f"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
+                )
+        if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
                 f"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}."
             )
         spatial_dims_set.add(int(boxes.shape[1] / 2))
     if points is not None:
+        if len(points.shape) != 2:
+            if points.shape[0] == 0:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], "
+                    f"got points with shape {points.shape}. "
+                    f"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3])."
+                )
+            else:
+                raise ValueError(
+                    f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
+                )
         if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
-                f"Currently we support only points with shape [N,2] or [N,3], got boxes with shape {points.shape}."
+                f"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}."
             )
         spatial_dims_set.add(int(points.shape[1]))
     if corners is not None:
-        if len(corners) not in [4, 6]:
+        if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:
             raise ValueError(
                 f"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}."
             )
@@ -494,6 +516,33 @@ def get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwar
     return StandardMode(*args, **kwargs)


+def standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
+    """
+    When boxes are empty, this function standardizes them to shape (0,4) or (0,6).
+
+    Args:
+        boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray.
+        spatial_dims: number of spatial dimensions of the bounding boxes.
+
+    Returns:
+        bounding boxes with shape (N,4) or (N,6), where N can be 0.
+
+    Example:
+        .. code-block:: python
+
+            boxes = torch.ones(0,)
+            standardize_empty_box(boxes, 3)
+    """
+    # convert numpy to tensor if needed
+    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)
+    # handle empty box
+    if boxes_t.shape[0] == 0:
+        boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])
+    # convert tensor back to numpy if needed
+    boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)
+    return boxes_dst
+
+
 def convert_box_mode(
     boxes: NdarrayOrTensor,
     src_mode: str | BoxMode | type[BoxMode] | None = None,
@@ -522,6 +571,10 @@
         convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
         convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
     """
+    # handle empty box
+    if boxes.shape[0] == 0:
+        return boxes
+
     src_boxmode = get_boxmode(src_mode)
     dst_boxmode = get_boxmode(dst_mode)
```
tests/test_retinanet_detector.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -134,20 +134,21 @@ def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):

         detector.set_atss_matcher()
         detector.set_hard_negative_sampler(10, 0.5)
-        gt_box_start = torch.randint(2, (3, input_param["spatial_dims"])).to(torch.float16)
-        gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param["spatial_dims"]))
-        one_target = {
-            "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
-            "labels": torch.randint(input_param["num_classes"], (3,)),
-        }
-        with train_mode(detector):
-            input_data = torch.randn(input_shape)
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
-
-            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
-            targets = [one_target] * len(input_data)
-            result = detector.forward(input_data, targets)
+        for num_gt_box in [0, 3]:  # test for both empty and non-empty boxes
+            gt_box_start = torch.randint(2, (num_gt_box, input_param["spatial_dims"])).to(torch.float16)
+            gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param["spatial_dims"]))
+            one_target = {
+                "boxes": torch.cat((gt_box_start, gt_box_end), dim=1),
+                "labels": torch.randint(input_param["num_classes"], (num_gt_box,)),
+            }
+            with train_mode(detector):
+                input_data = torch.randn(input_shape)
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)
+
+                input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]
+                targets = [one_target] * len(input_data)
+                result = detector.forward(input_data, targets)

     @parameterized.expand(TEST_CASES)
     def test_naive_retina_detector_shape(self, input_param, input_shape):
```
