
add support for dicom images #1050

Merged · 10 commits · Nov 18, 2022

14 changes: 11 additions & 3 deletions detection/README.md
@@ -19,6 +19,8 @@ The dataset we are experimenting in this example is LUNA16 (https://luna16.grand

LUNA16 is a public dataset of CT lung nodule detection. Using raw CT scans, the goal is to identify locations of possible nodules, and to assign a probability for being a nodule to each location.

Users can either download mhd/raw data from [LUNA16](https://luna16.grand-challenge.org/Home/), or DICOM data from [LIDC-IDRI](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254).

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.

We follow the official 10-fold data splitting from the LUNA16 challenge and generate the data split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
@@ -36,15 +38,21 @@ In these files, the values of "box" are the ground truth boxes in world coordina

The raw CT images in LUNA16 have various voxel sizes. The first step is to resample them to the same voxel size, which is defined in the value of "spacing" in [./config/config_train_luna16_16g.json](./config/config_train_luna16_16g.json).
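
As a quick sanity check before resampling, you can print the target spacing that will be used; a minimal sketch, assuming it is run from the `detection/` folder of the tutorial:

```python
import json

# Read the training config and show the target voxel spacing used for resampling.
with open("./config/config_train_luna16_16g.json", "r") as f:
    config = json.load(f)
print("target voxel spacing:", config["spacing"])
```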

Then, please open [luna16_prepare_env_files.py](luna16_prepare_env_files.py), change the value of "raw_data_base_dir" to the directory where you store the downloaded images, the value of "downloaded_datasplit_dir" to where you downloaded the data split json files, and the value of "resampled_data_base_dir" to the target directory where you will save the resampled images.
Then, please open [luna16_prepare_env_files.py](luna16_prepare_env_files.py), change the value of "raw_data_base_dir" to the directory where you store the downloaded images, the value of "downloaded_datasplit_dir" to where you downloaded the data split json files, and the value of "resampled_data_base_dir" to the target directory where you will save the resampled images. If you are using DICOM data, please also provide the path to "dicom_meta_data_csv", which can be found in the folder downloaded from LIDC-IDRI.
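
For illustration, the variables to edit in [luna16_prepare_env_files.py](luna16_prepare_env_files.py) look roughly like the following; the paths are placeholders for your own locations, and `dicom_meta_data_csv` stays `None` when working with mhd/raw data:

```python
# Placeholder values inside luna16_prepare_env_files.py; replace with your own paths.
raw_data_base_dir = "/path/to/LIDC-IDRI"                  # downloaded images
resampled_data_base_dir = "/path/to/LIDC-IDRI_resample"   # output folder for resampled images
downloaded_datasplit_dir = "./LUNA16_datasplit"           # downloaded data split json files
dicom_meta_data_csv = "/path/to/metadata.csv"             # set to None for mhd/raw data
```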

Finally, resample the images by running
If you downloaded mhd/raw data, please resample the images by running
```bash
python3 luna16_prepare_env_files.py
python3 luna16_prepare_images.py -c ./config/config_train_luna16_16g.json
```

The original images are with mhd/raw format, the resampled images will be with Nifti format.
If you downloaded DICOM data, please resample the images by running
```bash
python3 luna16_prepare_env_files.py
python3 luna16_prepare_images_dicom.py -c ./config/config_train_luna16_16g.json
```

The resampled images will be in Nifti format.
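
If you want to double-check one resampled volume, a small optional snippet with nibabel (the file name below is a placeholder) can confirm the output format and spacing:

```python
import nibabel as nib

# Placeholder path to one resampled image produced by the step above.
nii_path = "/datasets/example_series/example_series.nii.gz"

img = nib.load(nii_path)
print("shape:", img.shape)
print("voxel spacing:", img.header.get_zooms())
```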

#### [3.2 3D Detection Training](./luna16_training.py)

43 changes: 38 additions & 5 deletions detection/generate_transforms.py
@@ -25,9 +25,21 @@
RandFlipBoxd,
RandRotateBox90d,
RandZoomBoxd,
ConvertBoxModed
ConvertBoxModed,
)
def generate_detection_train_transform(image_key, box_key, label_key, gt_box_mode, intensity_transform, patch_size, batch_size, affine_lps_to_ras=False, amp=True):


def generate_detection_train_transform(
image_key,
box_key,
label_key,
gt_box_mode,
intensity_transform,
patch_size,
batch_size,
affine_lps_to_ras=False,
amp=True,
):
"""
Generate training transform for detection.

@@ -166,7 +178,16 @@ def generate_detection_train_transform(image_key, box_key, label_key, gt_box_mod
)
return train_transforms
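
For orientation, a hedged sketch of how this training transform factory might be called from a training script; the intensity transform, the box mode ("cccwhd"), and the patch/batch sizes below are illustrative assumptions rather than values taken from this PR, and the import assumes the sketch runs from the `detection/` folder:

```python
from monai.transforms import ScaleIntensityRanged

from generate_transforms import generate_detection_train_transform

# Illustrative arguments only; the intensity window, box mode, patch size,
# and batch size are placeholders chosen for the sketch.
intensity_transform = ScaleIntensityRanged(
    keys=["image"], a_min=-1024.0, a_max=300.0, b_min=0.0, b_max=1.0, clip=True
)
train_transforms = generate_detection_train_transform(
    image_key="image",
    box_key="box",
    label_key="label",
    gt_box_mode="cccwhd",
    intensity_transform=intensity_transform,
    patch_size=(192, 192, 80),
    batch_size=4,
    affine_lps_to_ras=True,
    amp=True,
)
```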

def generate_detection_val_transform(image_key, box_key, label_key, gt_box_mode, intensity_transform, affine_lps_to_ras=False, amp=True):

def generate_detection_val_transform(
image_key,
box_key,
label_key,
gt_box_mode,
intensity_transform,
affine_lps_to_ras=False,
amp=True,
):
"""
Generate validation transform for detection.

@@ -211,7 +232,17 @@ def generate_detection_val_transform(image_key, box_key, label_key, gt_box_mode,
)
return val_transforms
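
A similar sketch for the validation transform, plugged into a MONAI dataset and loader; the data list path, base directory, keys, and box mode are placeholders, while the loader arguments mirror the preparation script added later in this PR:

```python
from monai.data import DataLoader, Dataset, load_decathlon_datalist
from monai.data.utils import no_collation
from monai.transforms import ScaleIntensityRanged

from generate_transforms import generate_detection_val_transform

# Illustrative arguments only; paths, keys, and box mode are placeholders.
val_transforms = generate_detection_val_transform(
    image_key="image",
    box_key="box",
    label_key="label",
    gt_box_mode="cccwhd",
    intensity_transform=ScaleIntensityRanged(
        keys=["image"], a_min=-1024.0, a_max=300.0, b_min=0.0, b_max=1.0, clip=True
    ),
    affine_lps_to_ras=True,
)
val_data = load_decathlon_datalist(
    "./LUNA16_datasplit/dataset_fold0.json",
    is_segmentation=True,
    data_list_key="validation",
    base_dir="/datasets/",
)
val_loader = DataLoader(
    Dataset(data=val_data, transform=val_transforms),
    batch_size=1,
    num_workers=2,
    collate_fn=no_collation,
)
```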

def generate_detection_inference_transform(image_key, pred_box_key, pred_label_key, pred_score_key, gt_box_mode, intensity_transform, affine_lps_to_ras=False, amp=True):

def generate_detection_inference_transform(
image_key,
pred_box_key,
pred_label_key,
pred_score_key,
gt_box_mode,
intensity_transform,
affine_lps_to_ras=False,
amp=True,
):
"""
Generate inference transform for detection.

@@ -260,7 +291,9 @@ def generate_detection_inference_transform(image_key, pred_box_key, pred_label_k
image_meta_key_postfix="meta_dict",
affine_lps_to_ras=affine_lps_to_ras,
),
ConvertBoxModed(box_keys=[pred_box_key], src_mode = "xyzxyz", dst_mode=gt_box_mode),
ConvertBoxModed(
box_keys=[pred_box_key], src_mode="xyzxyz", dst_mode=gt_box_mode
),
DeleteItemsd(keys=[image_key]),
]
)
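
A brief note on the final `ConvertBoxModed` step above: it maps the predicted corner-format boxes ("xyzxyz") back to the ground-truth box mode. A minimal sketch, assuming the ground-truth mode is "cccwhd" (box centers followed by sizes):

```python
import torch
from monai.apps.detection.transforms.dictionary import ConvertBoxModed

# One predicted box given as [xmin, ymin, zmin, xmax, ymax, zmax].
data = {"pred_box": torch.tensor([[0.0, 0.0, 0.0, 10.0, 20.0, 30.0]])}
converted = ConvertBoxModed(box_keys=["pred_box"], src_mode="xyzxyz", dst_mode="cccwhd")(data)
# converted["pred_box"] now stores the box as centers plus sizes.
```
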
20 changes: 13 additions & 7 deletions detection/luna16_post_combine_cross_fold_results.py
@@ -3,12 +3,14 @@
import argparse
import os


def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument(
"-i",
"--input",
nargs='+', default=[],
nargs="+",
default=[],
help="input json",
)
parser.add_argument(
@@ -22,16 +24,20 @@ def main():
in_json_list = args.input
out_csv = args.output

with open(out_csv, 'w', newline='') as csvfile:
spamwriter = csv.writer(csvfile, delimiter=',',
quotechar='|', quoting=csv.QUOTE_MINIMAL)
spamwriter.writerow(['seriesuid','coordX','coordY','coordZ','probability'])
with open(out_csv, "w", newline="") as csvfile:
spamwriter = csv.writer(
csvfile, delimiter=",", quotechar="|", quoting=csv.QUOTE_MINIMAL
)
spamwriter.writerow(["seriesuid", "coordX", "coordY", "coordZ", "probability"])
for in_json in in_json_list:
result = json.load(open(in_json, "r"))
for subj in result["validation"]:
seriesuid = os.path.split(subj["image"])[-1][:-7]
for b in range(len(subj["box"])):
spamwriter.writerow([seriesuid]+subj["box"][b][0:3]+[subj["score"][b]])
spamwriter.writerow(
[seriesuid] + subj["box"][b][0:3] + [subj["score"][b]]
)


if __name__ == '__main__':
if __name__ == "__main__":
main()
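
For reference, a sketch of the per-fold result json this script consumes, based on the fields read above ("validation", "image", "box", "score"); the file content shown is a placeholder:

```python
# Illustrative content of one input json passed via "-i"; values are placeholders.
example_result = {
    "validation": [
        {
            "image": "subject0001/subject0001.nii.gz",
            "box": [[-52.1, 30.4, -110.7, 8.0, 8.0, 8.0]],  # predicted boxes, one list of 6 values per box
            "score": [0.97],
        }
    ]
}
# The script writes one csv row per box: the seriesuid (file name without ".nii.gz"),
# the first three box values as coordX/coordY/coordZ, and the score as probability.
```
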
58 changes: 45 additions & 13 deletions detection/luna16_prepare_env_files.py
@@ -14,15 +14,31 @@
import logging
import sys


def main():
# ------------- Modification starts -------------
raw_data_base_dir = "/orig_datasets/" # the directory of the raw images
resampled_data_base_dir = "/datasets/" # the directory of the resampled images
downloaded_datasplit_dir = "LUNA16_datasplit" # the directory of downloaded data split files
raw_data_base_dir = "/home/canz/Projects/datasets/LIDC/manifest-1600709154662/LIDC-IDRI" # the directory of the raw images
resampled_data_base_dir = "/home/canz/Projects/datasets/LIDC/manifest-1600709154662/LIDC-IDRI_resample" # the directory of the resampled images
downloaded_datasplit_dir = (
"./LUNA16_datasplit" # the directory of downloaded data split files
)

out_trained_models_dir = (
"trained_models" # the directory to save trained model weights
)
out_tensorboard_events_dir = (
"tfevent_train" # the directory to save tensorboard training curves
)
out_inference_result_dir = (
"result" # the directory to save predicted boxes for inference
)

out_trained_models_dir = "trained_models" # the directory to save trained model weights
out_tensorboard_events_dir = "tfevent_train" # the directory to save tensorboard training curves
out_inference_result_dir = "result" # the directory to save predicted boxes for inference
# if dealing with mhd/raw data, set it to None
# dicom_meta_data_csv = None
# if dealing with DICOM data, metadata.csv is also needed
dicom_meta_data_csv = (
"/home/canz/Projects/datasets/LIDC/manifest-1600709154662/metadata.csv"
)
# ------------- Modification ends ---------------

try:
@@ -45,20 +61,36 @@ def main():
env_dict = {}
env_dict["orig_data_base_dir"] = raw_data_base_dir
env_dict["data_base_dir"] = resampled_data_base_dir
env_dict["data_list_file_path"] = os.path.join(downloaded_datasplit_dir,"original/dataset_fold0.json")
if dicom_meta_data_csv is not None:
env_dict["data_list_file_path"] = os.path.join(
downloaded_datasplit_dir, "dicom_original/dataset_fold0.json"
)
else:
env_dict["data_list_file_path"] = os.path.join(
downloaded_datasplit_dir, "mhd_original/dataset_fold0.json"
)
if dicom_meta_data_csv is not None:
env_dict["dicom_meta_data_csv"] = dicom_meta_data_csv
with open(out_file, "w") as outfile:
json.dump(env_dict, outfile, indent=4)


# generate env json file for training and inference
for fold in range(10):
out_file = "config/environment_luna16_fold"+str(fold)+".json"
out_file = "config/environment_luna16_fold" + str(fold) + ".json"
env_dict = {}
env_dict["model_path"] = os.path.join(out_trained_models_dir,"model_luna16_fold"+str(fold)+".pt")
env_dict["model_path"] = os.path.join(
out_trained_models_dir, "model_luna16_fold" + str(fold) + ".pt"
)
env_dict["data_base_dir"] = resampled_data_base_dir
env_dict["data_list_file_path"] = os.path.join(downloaded_datasplit_dir,"dataset_fold"+str(fold)+".json")
env_dict["tfevent_path"] = os.path.join(out_tensorboard_events_dir,"luna16_fold"+str(fold))
env_dict["result_list_file_path"] = os.path.join(out_inference_result_dir,"result_luna16_fold"+str(fold)+".json")
env_dict["data_list_file_path"] = os.path.join(
downloaded_datasplit_dir, "dataset_fold" + str(fold) + ".json"
)
env_dict["tfevent_path"] = os.path.join(
out_tensorboard_events_dir, "luna16_fold" + str(fold)
)
env_dict["result_list_file_path"] = os.path.join(
out_inference_result_dir, "result_luna16_fold" + str(fold) + ".json"
)
with open(out_file, "w") as outfile:
json.dump(env_dict, outfile, indent=4)

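For illustration, the environment file used for image preparation (the preparation scripts default to config/environment_luna16_prepare.json) would look roughly like this when dicom_meta_data_csv is set; the paths are placeholders:

```python
# Illustrative content of config/environment_luna16_prepare.json (placeholder paths).
env_example = {
    "orig_data_base_dir": "/path/to/LIDC-IDRI",
    "data_base_dir": "/path/to/LIDC-IDRI_resample",
    "data_list_file_path": "./LUNA16_datasplit/dicom_original/dataset_fold0.json",
    "dicom_meta_data_csv": "/path/to/metadata.csv",
}
```
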
132 changes: 132 additions & 0 deletions detection/luna16_prepare_images_dicom.py
@@ -0,0 +1,132 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import sys
import os
import csv
from pathlib import Path

import monai
import torch
from monai.data import DataLoader, Dataset, load_decathlon_datalist, NibabelWriter
from monai.data.utils import no_collation
from monai.transforms import (
Compose,
EnsureChannelFirstd,
EnsureTyped,
LoadImaged,
Orientationd,
Spacingd,
)


def main():
parser = argparse.ArgumentParser(description="LUNA16 Detection Image Resampling")
parser.add_argument(
"-e",
"--environment-file",
default="./config/environment_luna16_prepare.json",
help="environment json file that stores environment path",
)
parser.add_argument(
"-c",
"--config-file",
default="./config/config_train_luna16_16g.json",
help="config json file that stores hyper-parameters",
)
args = parser.parse_args()

monai.config.print_config()

env_dict = json.load(open(args.environment_file, "r"))
config_dict = json.load(open(args.config_file, "r"))

for k, v in env_dict.items():
setattr(args, k, v)
for k, v in config_dict.items():
setattr(args, k, v)

# 1. define transform
# resample images to args.spacing defined in args.config_file.
process_transforms = Compose(
[
LoadImaged(
keys=["image"],
meta_key_postfix="meta_dict",
reader="itkreader",
affine_lps_to_ras=True,
),
EnsureChannelFirstd(keys=["image"]),
EnsureTyped(keys=["image"], dtype=torch.float16),
Orientationd(keys=["image"], axcodes="RAS"),
Spacingd(keys=["image"], pixdim=args.spacing, padding_mode="border"),
]
)

# 2. prepare data
meta_dict = {}
with open(env_dict["dicom_meta_data_csv"], newline="") as csvfile:
print("open " + env_dict["dicom_meta_data_csv"])
reader = csv.DictReader(csvfile)
for row in reader:
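# map the tail of "File Location" (its first 12 characters, e.g. a leading "./LIDC-IDRI/" prefix,
# are dropped) to the corresponding Series UID, so resampled files can be named by series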
meta_dict[row["File Location"][12:]] = str(row["Series UID"])

for data_list_key in ["training", "validation"]:
# create a data loader
process_data = load_decathlon_datalist(
args.data_list_file_path,
is_segmentation=True,
data_list_key=data_list_key,
base_dir=args.orig_data_base_dir,
)
process_ds = Dataset(
data=process_data,
transform=process_transforms,
)
process_loader = DataLoader(
process_ds,
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=False,
collate_fn=no_collation,
)

print("-" * 10)
for batch_data in process_loader:
for batch_data_i in batch_data:
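# look up the Series UID from the last three components of the image path
# (typically the patient/study/series folders), matching the keys built from metadata.csv above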
subj_id = meta_dict[
"/".join(
batch_data_i["image_meta_dict"]["filename_or_obj"].split("/")[
-3:
]
)
]
new_path = os.path.join(args.data_base_dir, subj_id)
Path(new_path).mkdir(parents=True, exist_ok=True)
new_filename = os.path.join(new_path, subj_id + ".nii.gz")
writer = NibabelWriter()
writer.set_data_array(data_array=batch_data_i["image"])
writer.set_metadata(meta_dict=batch_data_i["image"].meta)
writer.write(new_filename, verbose=True)


if __name__ == "__main__":
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
main()