Skip to content

Commit 641a079

Browse files
Nic-Mamonai-bot
andauthored
4173 Enhance decathlon datalist for test section format (#4186)
* [DLMED] enhance test datalist Signed-off-by: Nic Ma <[email protected]> * [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]> * [DLMED] fix flake8 Signed-off-by: Nic Ma <[email protected]> Co-authored-by: monai-bot <[email protected]>
1 parent 5c15138 commit 641a079

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

monai/data/decathlon_datalist.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def load_decathlon_datalist(
122122
if data_list_key not in json_data:
123123
raise ValueError(f'Data list {data_list_key} not specified in "{data_list_file_path}".')
124124
expected_data = json_data[data_list_key]
125-
if data_list_key == "test":
125+
if data_list_key == "test" and not isinstance(expected_data[0], dict):
126+
# decathlon datalist may save the test images in a list directly instead of dict
126127
expected_data = [{"image": i} for i in expected_data]
127128

128129
if base_dir is None:

monai/metrics/confusion_matrix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
102102

103103
return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)
104104

105-
def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None): # type: ignore
105+
def aggregate( # type: ignore
106+
self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None
107+
):
106108
"""
107109
Execute reduction for the confusion matrix values.
108110

tests/test_load_decathlon_datalist.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_seg_values(self):
2929
{"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"},
3030
{"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"},
3131
],
32-
"test": ["spleen_15.nii.gz", "spleen_23.nii.gz"],
32+
"test": [{"image": "spleen_15.nii.gz"}, {"image": "spleen_23.nii.gz"}],
3333
}
3434
json_str = json.dumps(test_data)
3535
file_path = os.path.join(tempdir, "test_data.json")
@@ -38,6 +38,8 @@ def test_seg_values(self):
3838
result = load_decathlon_datalist(file_path, True, "training", tempdir)
3939
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
4040
self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
41+
result = load_decathlon_datalist(file_path, True, "test", None)
42+
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz"))
4143

4244
def test_cls_values(self):
4345
with tempfile.TemporaryDirectory() as tempdir:
@@ -81,6 +83,8 @@ def test_seg_no_basedir(self):
8183
result = load_decathlon_datalist(file_path, True, "training", None)
8284
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz"))
8385
self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz"))
86+
result = load_decathlon_datalist(file_path, True, "test", None)
87+
self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz"))
8488

8589
def test_seg_no_labels(self):
8690
with tempfile.TemporaryDirectory() as tempdir:

0 commit comments

Comments
 (0)