Commit 062a95d

add support for dicom images (#1050)
Signed-off-by: Can Zhao <[email protected]>

Fixes #.

### Description
The original source of the mhd/raw images was removed, so this commit adds support for DICOM images. It also reformats the code.

### Checks
- [ ] Notebook runs automatically `./runner [-p <regex_pattern>]`

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7eccb85 commit 062a95d

8 files changed, +302 -53 lines changed

detection/README.md

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@ The dataset we are experimenting in this example is LUNA16 (https://luna16.grand
1919

2020
LUNA16 is a public dataset of CT lung nodule detection. Using raw CT scans, the goal is to identify locations of possible nodules, and to assign a probability for being a nodule to each location.
2121

22+
Users can either download mhd/raw data from [LUNA16](https://luna16.grand-challenge.org/Home/), or DICOM data from [LIDC-IDRI](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254).
23+
2224
Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.
2325

2426
We follow the official 10-fold data splitting from LUNA16 challenge and generate data split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
@@ -36,15 +38,21 @@ In these files, the values of "box" are the ground truth boxes in world coordina
3638

3739
The raw CT images in LUNA16 have various voxel sizes. The first step is to resample them to the same voxel size, which is defined by the value of "spacing" in [./config/config_train_luna16_16g.json](./config/config_train_luna16_16g.json).
3840

39-
Then, please open [luna16_prepare_env_files.py](luna16_prepare_env_files.py), change the value of "raw_data_base_dir" to the directory where you store the downloaded images, the value of "downloaded_datasplit_dir" to where you downloaded the data split json files, and the value of "resampled_data_base_dir" to the target directory where you will save the resampled images.
41+
Then, please open [luna16_prepare_env_files.py](luna16_prepare_env_files.py) and change the value of "raw_data_base_dir" to the directory where you store the downloaded images, the value of "downloaded_datasplit_dir" to the directory where you downloaded the data split json files, and the value of "resampled_data_base_dir" to the target directory where the resampled images will be saved. If you are using DICOM data, please also set "dicom_meta_data_csv" to the path of the metadata csv file, which can be found in the folder downloaded from LIDC-IDRI.
4042

41-
Finally, resample the images by running
43+
If you downloaded mhd/raw data, please resample the images by running
4244
```bash
4345
python3 luna16_prepare_env_files.py
4446
python3 luna16_prepare_images.py -c ./config/config_train_luna16_16g.json
4547
```
4648

47-
The original images are with mhd/raw format, the resampled images will be with Nifti format.
49+
If you downloaded DICOM data, please resample the images by running
50+
```bash
51+
python3 luna16_prepare_env_files.py
52+
python3 luna16_prepare_images_dicom.py -c ./config/config_train_luna16_16g.json
53+
```
54+
55+
The resampled images will be in NIfTI format.
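The choice between the two commands above depends only on the data format. As a minimal sketch (the helper name `build_prepare_command` is hypothetical and not part of the repository), the selection could be expressed in Python as follows, assuming that a `dicom_meta_data_csv` value other than `None` means DICOM data:

```python
def build_prepare_command(
    dicom_meta_data_csv,
    config_file="./config/config_train_luna16_16g.json",
):
    """Return the resampling command as an argument list.

    Hypothetical helper for illustration only: it picks the DICOM script
    when a metadata csv path is given, and the mhd/raw script otherwise.
    """
    if dicom_meta_data_csv is not None:
        script = "luna16_prepare_images_dicom.py"
    else:
        script = "luna16_prepare_images.py"
    return ["python3", script, "-c", config_file]


if __name__ == "__main__":
    # mhd/raw data: no metadata csv is needed
    print(" ".join(build_prepare_command(None)))
    # DICOM data: a metadata csv path is provided (placeholder value)
    print(" ".join(build_prepare_command("metadata.csv")))
```

In both cases `luna16_prepare_env_files.py` still has to be run first, as the commands above show.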
4856

4957
#### [3.2 3D Detection Training](./luna16_training.py)
5058

detection/generate_transforms.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,21 @@
2525
RandFlipBoxd,
2626
RandRotateBox90d,
2727
RandZoomBoxd,
28-
ConvertBoxModed
28+
ConvertBoxModed,
2929
)
30-
def generate_detection_train_transform(image_key, box_key, label_key, gt_box_mode, intensity_transform, patch_size, batch_size, affine_lps_to_ras=False, amp=True):
30+
31+
32+
def generate_detection_train_transform(
33+
image_key,
34+
box_key,
35+
label_key,
36+
gt_box_mode,
37+
intensity_transform,
38+
patch_size,
39+
batch_size,
40+
affine_lps_to_ras=False,
41+
amp=True,
42+
):
3143
"""
3244
Generate training transform for detection.
3345
@@ -166,7 +178,16 @@ def generate_detection_train_transform(image_key, box_key, label_key, gt_box_mod
166178
)
167179
return train_transforms
168180

169-
def generate_detection_val_transform(image_key, box_key, label_key, gt_box_mode, intensity_transform, affine_lps_to_ras=False, amp=True):
181+
182+
def generate_detection_val_transform(
183+
image_key,
184+
box_key,
185+
label_key,
186+
gt_box_mode,
187+
intensity_transform,
188+
affine_lps_to_ras=False,
189+
amp=True,
190+
):
170191
"""
171192
Generate validation transform for detection.
172193
@@ -211,7 +232,17 @@ def generate_detection_val_transform(image_key, box_key, label_key, gt_box_mode,
211232
)
212233
return val_transforms
213234

214-
def generate_detection_inference_transform(image_key, pred_box_key, pred_label_key, pred_score_key, gt_box_mode, intensity_transform, affine_lps_to_ras=False, amp=True):
235+
236+
def generate_detection_inference_transform(
237+
image_key,
238+
pred_box_key,
239+
pred_label_key,
240+
pred_score_key,
241+
gt_box_mode,
242+
intensity_transform,
243+
affine_lps_to_ras=False,
244+
amp=True,
245+
):
215246
"""
216247
Generate validation transform for detection.
217248
@@ -260,7 +291,9 @@ def generate_detection_inference_transform(image_key, pred_box_key, pred_label_k
260291
image_meta_key_postfix="meta_dict",
261292
affine_lps_to_ras=affine_lps_to_ras,
262293
),
263-
ConvertBoxModed(box_keys=[pred_box_key], src_mode = "xyzxyz", dst_mode=gt_box_mode),
294+
ConvertBoxModed(
295+
box_keys=[pred_box_key], src_mode="xyzxyz", dst_mode=gt_box_mode
296+
),
264297
DeleteItemsd(keys=[image_key]),
265298
]
266299
)
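`ConvertBoxModed` in the inference transform converts predicted boxes from the `"xyzxyz"` corner mode to `gt_box_mode`. As a rough illustration of what such a conversion computes, the plain-Python sketch below converts one box to a center-and-size layout. It assumes `gt_box_mode` is `"cccwhd"` (this diff does not show its value), and the function name is hypothetical; it is not MONAI's implementation.

```python
def xyzxyz_to_cccwhd(box):
    """Convert [xmin, ymin, zmin, xmax, ymax, zmax] to
    [xcenter, ycenter, zcenter, xsize, ysize, zsize].

    Illustrative sketch only; assumes the target mode is "cccwhd".
    """
    xmin, ymin, zmin, xmax, ymax, zmax = box
    return [
        (xmin + xmax) / 2,  # center along x
        (ymin + ymax) / 2,  # center along y
        (zmin + zmax) / 2,  # center along z
        xmax - xmin,        # size along x
        ymax - ymin,        # size along y
        zmax - zmin,        # size along z
    ]


if __name__ == "__main__":
    print(xyzxyz_to_cccwhd([0.0, 0.0, 0.0, 2.0, 4.0, 6.0]))
```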

detection/luna16_post_combine_cross_fold_results.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import argparse
44
import os
55

6+
67
def main():
78
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
89
parser.add_argument(
910
"-i",
1011
"--input",
11-
nargs='+', default=[],
12+
nargs="+",
13+
default=[],
1214
help="input json",
1315
)
1416
parser.add_argument(
@@ -22,16 +24,20 @@ def main():
2224
in_json_list = args.input
2325
out_csv = args.output
2426

25-
with open(out_csv, 'w', newline='') as csvfile:
26-
spamwriter = csv.writer(csvfile, delimiter=',',
27-
quotechar='|', quoting=csv.QUOTE_MINIMAL)
28-
spamwriter.writerow(['seriesuid','coordX','coordY','coordZ','probability'])
27+
with open(out_csv, "w", newline="") as csvfile:
28+
spamwriter = csv.writer(
29+
csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
30+
)
31+
spamwriter.writerow(["seriesuid", "coordX", "coordY", "coordZ", "probability"])
2932
for in_json in in_json_list:
3033
result = json.load(open(in_json, "r"))
3134
for subj in result["validation"]:
3235
seriesuid = os.path.split(subj["image"])[-1][:-7]
3336
for b in range(len(subj["box"])):
34-
spamwriter.writerow([seriesuid]+subj["box"][b][0:3]+[subj["score"][b]])
37+
spamwriter.writerow(
38+
[seriesuid] + subj["box"][b][0:3] + [subj["score"][b]]
39+
)
40+
3541

36-
if __name__ == '__main__':
42+
if __name__ == "__main__":
3743
main()
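To make the row layout written by this script easier to see, here is a small self-contained sketch of the same per-box logic. The function name `rows_from_result` and the sample values are made up for illustration; the sketch assumes, as the slicing `[:-7]` suggests, that image file names end in `.nii.gz`.

```python
import os


def rows_from_result(result):
    """Build csv rows (seriesuid, coordX, coordY, coordZ, probability)
    from one result dictionary. Hypothetical helper for illustration."""
    rows = []
    for subj in result["validation"]:
        # Strip the 7-character ".nii.gz" suffix from the file name
        seriesuid = os.path.split(subj["image"])[-1][:-7]
        for box, score in zip(subj["box"], subj["score"]):
            # Keep only the first three box values (the box center coordinates)
            rows.append([seriesuid] + box[0:3] + [score])
    return rows


if __name__ == "__main__":
    # Made-up example input, not real prediction output
    example = {
        "validation": [
            {
                "image": "some_dir/1.2.3.nii.gz",
                "box": [[10.0, 20.0, 30.0, 4.0, 4.0, 4.0]],
                "score": [0.9],
            }
        ]
    }
    for row in rows_from_result(example):
        print(row)
```

Using `zip` over boxes and scores is a stylistic choice in this sketch; the script itself indexes both lists with `range(len(...))`.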

detection/luna16_prepare_env_files.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,31 @@
1414
import logging
1515
import sys
1616

17+
1718
def main():
1819
# ------------- Modification starts -------------
19-
raw_data_base_dir = "/orig_datasets/" # the directory of the raw images
20-
resampled_data_base_dir = "/datasets/" # the directory of the resampled images
21-
downloaded_datasplit_dir = "LUNA16_datasplit" # the directory of downloaded data split files
20+
raw_data_base_dir = "/home/canz/Projects/datasets/LIDC/manifest-1600709154662/LIDC-IDRI" # the directory of the raw images
21+
resampled_data_base_dir = "/home/canz/Projects/datasets/LIDC/manifest-1600709154662/LIDC-IDRI_resample" # the directory of the resampled images
22+
downloaded_datasplit_dir = (
23+
"./LUNA16_datasplit" # the directory of downloaded data split files
24+
)
25+
26+
out_trained_models_dir = (
27+
"trained_models" # the directory to save trained model weights
28+
)
29+
out_tensorboard_events_dir = (
30+
"tfevent_train" # the directory to save tensorboard training curves
31+
)
32+
out_inference_result_dir = (
33+
"result" # the directory to save predicted boxes for inference
34+
)
2235

23-
out_trained_models_dir = "trained_models" # the directory to save trained model weights
24-
out_tensorboard_events_dir = "tfevent_train" # the directory to save tensorboard training curves
25-
out_inference_result_dir = "result" # the directory to save predicted boxes for inference
36+
# if dealing with mhd/raw data, set it to None
37+
# dicom_meta_data_csv = None
38+
# if dealing with DICOM data, also provide the path to metadata.csv
39+
dicom_meta_data_csv = (
40+
"/home/canz/Projects/datasets/LIDC/manifest-1600709154662/metadata.csv"
41+
)
2642
# ------------- Modification ends ---------------
2743

2844
try:
@@ -45,20 +61,36 @@ def main():
4561
env_dict = {}
4662
env_dict["orig_data_base_dir"] = raw_data_base_dir
4763
env_dict["data_base_dir"] = resampled_data_base_dir
48-
env_dict["data_list_file_path"] = os.path.join(downloaded_datasplit_dir,"original/dataset_fold0.json")
64+
if dicom_meta_data_csv is not None:
65+
env_dict["data_list_file_path"] = os.path.join(
66+
downloaded_datasplit_dir, "dicom_original/dataset_fold0.json"
67+
)
68+
else:
69+
env_dict["data_list_file_path"] = os.path.join(
70+
downloaded_datasplit_dir, "mhd_original/dataset_fold0.json"
71+
)
72+
if dicom_meta_data_csv is not None:
73+
env_dict["dicom_meta_data_csv"] = dicom_meta_data_csv
4974
with open(out_file, "w") as outfile:
5075
json.dump(env_dict, outfile, indent=4)
5176

52-
5377
# generate env json file for training and inference
5478
for fold in range(10):
55-
out_file = "config/environment_luna16_fold"+str(fold)+".json"
79+
out_file = "config/environment_luna16_fold" + str(fold) + ".json"
5680
env_dict = {}
57-
env_dict["model_path"] = os.path.join(out_trained_models_dir,"model_luna16_fold"+str(fold)+".pt")
81+
env_dict["model_path"] = os.path.join(
82+
out_trained_models_dir, "model_luna16_fold" + str(fold) + ".pt"
83+
)
5884
env_dict["data_base_dir"] = resampled_data_base_dir
59-
env_dict["data_list_file_path"] = os.path.join(downloaded_datasplit_dir,"dataset_fold"+str(fold)+".json")
60-
env_dict["tfevent_path"] = os.path.join(out_tensorboard_events_dir,"luna16_fold"+str(fold))
61-
env_dict["result_list_file_path"] = os.path.join(out_inference_result_dir,"result_luna16_fold"+str(fold)+".json")
85+
env_dict["data_list_file_path"] = os.path.join(
86+
downloaded_datasplit_dir, "dataset_fold" + str(fold) + ".json"
87+
)
88+
env_dict["tfevent_path"] = os.path.join(
89+
out_tensorboard_events_dir, "luna16_fold" + str(fold)
90+
)
91+
env_dict["result_list_file_path"] = os.path.join(
92+
out_inference_result_dir, "result_luna16_fold" + str(fold) + ".json"
93+
)
6294
with open(out_file, "w") as outfile:
6395
json.dump(env_dict, outfile, indent=4)
6496
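The branching on `dicom_meta_data_csv` in this file can be summarized in one small function. The sketch below is only an illustration: `build_prepare_env` is a hypothetical name, the paths in the usage example are placeholders, and the sketch folds the two separate checks from the diff into a single `is not None` test.

```python
import json
import os


def build_prepare_env(raw_dir, resampled_dir, datasplit_dir, dicom_meta_data_csv=None):
    """Return the environment dictionary for the image preparation step.

    Hypothetical helper mirroring the logic in the diff above.
    """
    env_dict = {
        "orig_data_base_dir": raw_dir,
        "data_base_dir": resampled_dir,
    }
    if dicom_meta_data_csv is not None:
        # DICOM data: use the DICOM data split and record the metadata csv
        env_dict["data_list_file_path"] = os.path.join(
            datasplit_dir, "dicom_original/dataset_fold0.json"
        )
        env_dict["dicom_meta_data_csv"] = dicom_meta_data_csv
    else:
        # mhd/raw data: use the mhd data split
        env_dict["data_list_file_path"] = os.path.join(
            datasplit_dir, "mhd_original/dataset_fold0.json"
        )
    return env_dict


if __name__ == "__main__":
    # Placeholder paths for illustration
    env = build_prepare_env("/orig_datasets/", "/datasets/", "./LUNA16_datasplit")
    print(json.dumps(env, indent=4))
```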

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import argparse
13+
import json
14+
import logging
15+
import sys
16+
import os
17+
import csv
18+
from pathlib import Path
19+
20+
import monai
21+
import torch
22+
from monai.data import DataLoader, Dataset, load_decathlon_datalist, NibabelWriter
23+
from monai.data.utils import no_collation
24+
from monai.transforms import (
25+
Compose,
26+
EnsureChannelFirstd,
27+
EnsureTyped,
28+
LoadImaged,
29+
Orientationd,
30+
Spacingd,
31+
)
32+
33+
34+
def main():
35+
parser = argparse.ArgumentParser(description="LUNA16 Detection Image Resampling")
36+
parser.add_argument(
37+
"-e",
38+
"--environment-file",
39+
default="./config/environment_luna16_prepare.json",
40+
help="environment json file that stores environment path",
41+
)
42+
parser.add_argument(
43+
"-c",
44+
"--config-file",
45+
default="./config/config_train_luna16_16g.json",
46+
help="config json file that stores hyper-parameters",
47+
)
48+
args = parser.parse_args()
49+
50+
monai.config.print_config()
51+
52+
env_dict = json.load(open(args.environment_file, "r"))
53+
config_dict = json.load(open(args.config_file, "r"))
54+
55+
for k, v in env_dict.items():
56+
setattr(args, k, v)
57+
for k, v in config_dict.items():
58+
setattr(args, k, v)
59+
60+
# 1. define transform
61+
# resample images to args.spacing defined in args.config_file.
62+
process_transforms = Compose(
63+
[
64+
LoadImaged(
65+
keys=["image"],
66+
meta_key_postfix="meta_dict",
67+
reader="itkreader",
68+
affine_lps_to_ras=True,
69+
),
70+
EnsureChannelFirstd(keys=["image"]),
71+
EnsureTyped(keys=["image"], dtype=torch.float16),
72+
Orientationd(keys=["image"], axcodes="RAS"),
73+
Spacingd(keys=["image"], pixdim=args.spacing, padding_mode="border"),
74+
]
75+
)
76+
77+
# 2. prepare data
78+
meta_dict = {}
79+
with open(env_dict["dicom_meta_data_csv"], newline="") as csvfile:
80+
print("open " + env_dict["dicom_meta_data_csv"])
81+
reader = csv.DictReader(csvfile)
82+
for row in reader:
83+
meta_dict[row["File Location"][12:]] = str(row["Series UID"])
84+
85+
for data_list_key in ["training", "validation"]:
86+
# create a data loader
87+
process_data = load_decathlon_datalist(
88+
args.data_list_file_path,
89+
is_segmentation=True,
90+
data_list_key=data_list_key,
91+
base_dir=args.orig_data_base_dir,
92+
)
93+
process_ds = Dataset(
94+
data=process_data,
95+
transform=process_transforms,
96+
)
97+
process_loader = DataLoader(
98+
process_ds,
99+
batch_size=1,
100+
shuffle=False,
101+
num_workers=1,
102+
pin_memory=False,
103+
collate_fn=no_collation,
104+
)
105+
106+
print("-" * 10)
107+
for batch_data in process_loader:
108+
for batch_data_i in batch_data:
109+
subj_id = meta_dict[
110+
"/".join(
111+
batch_data_i["image_meta_dict"]["filename_or_obj"].split("/")[
112+
-3:
113+
]
114+
)
115+
]
116+
new_path = os.path.join(args.data_base_dir, subj_id)
117+
Path(new_path).mkdir(parents=True, exist_ok=True)
118+
new_filename = os.path.join(new_path, subj_id + ".nii.gz")
119+
writer = NibabelWriter()
120+
writer.set_data_array(data_array=batch_data_i["image"])
121+
writer.set_metadata(meta_dict=batch_data_i["image"].meta)
122+
writer.write(new_filename, verbose=True)
123+
124+
125+
if __name__ == "__main__":
126+
logging.basicConfig(
127+
stream=sys.stdout,
128+
level=logging.INFO,
129+
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
130+
datefmt="%Y-%m-%d %H:%M:%S",
131+
)
132+
main()
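The step that maps a loaded DICOM series back to its Series UID may be easier to follow in isolation. The sketch below reproduces that lookup with the standard library only. It assumes, based on the `[12:]` slice in the script, that each "File Location" value starts with a 12-character prefix such as `./LIDC-IDRI/` followed by three path components; the function names and the sample csv content are hypothetical.

```python
import csv
import io

PREFIX_LEN = 12  # assumed length of a leading "./LIDC-IDRI/" prefix


def build_series_lookup(csvfile):
    """Map the trailing part of "File Location" to "Series UID"."""
    lookup = {}
    for row in csv.DictReader(csvfile):
        lookup[row["File Location"][PREFIX_LEN:]] = str(row["Series UID"])
    return lookup


def series_uid_for(path, lookup):
    """Look up the Series UID using the last three path components."""
    key = "/".join(path.split("/")[-3:])
    return lookup[key]


if __name__ == "__main__":
    # Made-up metadata content, not taken from the real metadata.csv
    sample = io.StringIO(
        "File Location,Series UID\n"
        "./LIDC-IDRI/patient-A/study-B/series-C,1.2.3\n"
    )
    lookup = build_series_lookup(sample)
    print(series_uid_for("/data/LIDC-IDRI/patient-A/study-B/series-C", lookup))
```

Like the script, this sketch splits paths on `/`, so it assumes POSIX-style paths in both the csv file and the loaded file names.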
