Skip to content

Commit 7db7e6a

Browse files
authored
Merge pull request #384 from Project-MONAI/vikash/breast_density_map
A breast density monai application package
2 parents c5ce6c4 + a4352b7 commit 7db7e6a

File tree

5 files changed

+197
-0
lines changed

5 files changed

+197
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
## A MONAI Application Package to deploy breast density classification algorithm
2+
This MAP is based on the Breast Density Model in MONAI [Model-Zoo](https://github.com/Project-MONAI/model-zoo). This model is developed at the Center for Augmented Intelligence in Imaging at the Mayo Clinic, Florida.
3+
For any questions, feel free to contact Vikash Gupta ([email protected])
4+
Sample data and a torchscript model can be downloaded from https://drive.google.com/drive/folders/1Dryozl2MwNunpsGaFPVoaKBLkNbVM3Hu?usp=sharing
5+
6+
7+
## Run the application package
8+
### Python CLI
9+
```
10+
python app.py -i <input_dir> -o <out_dir> -m <breast_density_model>
11+
```
12+
13+
### MONAI Deploy CLI
14+
```
15+
monai-deploy exec app.py -i <input_dir> -o <out_dir> -m <breast_density_model>
16+
```
17+
Alternatively, you can go a level higher and execute
18+
```
19+
monai-deploy exec breast_density_classification_app -i <input_dir> -o <out_dir> -m <breast_density_model>
20+
```
21+
22+
23+
### Packaging the monai app
24+
In order to build the monai app, Go a level up and execute the following command.
25+
```
26+
monai-deploy package -b nvcr.io/nvidia/pytorch:21.12-py3 breast_density_classification_app --tag breast_density:0.1.0 -m $breast_density_model
27+
```
28+
29+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import os
2+
import sys
3+
4+
_current_dir = os.path.abspath(os.path.dirname(__file__))
5+
if sys.path and os.path.abspath(sys.path[0]) != _current_dir:
6+
sys.path.insert(0, _current_dir)
7+
del _current_dir
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from app import BreastClassificationApp
2+
3+
if __name__ == "__main__":
4+
BreastClassificationApp(do_run=True)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from breast_density_classifier_operator import ClassifierOperator
2+
3+
from monai.deploy.core import Application
4+
from monai.deploy.operators.dicom_data_loader_operator import DICOMDataLoaderOperator
5+
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
6+
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
7+
from monai.deploy.operators.dicom_text_sr_writer_operator import DICOMTextSRWriterOperator, EquipmentInfo, ModelInfo
8+
9+
10+
class BreastClassificationApp(Application):
11+
def __init__(self, *args, **kwargs):
12+
super().__init__(*args, **kwargs)
13+
14+
def compose(self):
15+
model_info = ModelInfo(
16+
"MONAI Model for Breast Density",
17+
"BreastDensity",
18+
"0.1",
19+
"Center for Augmented Intelligence in Imaging, Mayo Clinic, Florida",
20+
)
21+
my_equipment = EquipmentInfo(manufacturer="MONAI Deploy App SD", manufacturer_model="DICOM SR Writer")
22+
my_special_tags = {"SeriesDescription": "Not for clinical use"}
23+
study_loader_op = DICOMDataLoaderOperator()
24+
series_selector_op = DICOMSeriesSelectorOperator(rules="")
25+
series_to_vol_op = DICOMSeriesToVolumeOperator()
26+
classifier_op = ClassifierOperator()
27+
sr_writer_op = DICOMTextSRWriterOperator(
28+
copy_tags=False, model_info=model_info, equipment_info=my_equipment, custom_tags=my_special_tags
29+
)
30+
31+
self.add_flow(study_loader_op, series_selector_op, {"dicom_study_list": "dicom_study_list"})
32+
self.add_flow(
33+
series_selector_op, series_to_vol_op, {"study_selected_series_list": "study_selected_series_list"}
34+
)
35+
self.add_flow(series_to_vol_op, classifier_op, {"image": "image"})
36+
self.add_flow(classifier_op, sr_writer_op, {"result_text": "classification_result"})
37+
38+
39+
def main():
40+
app = BreastClassificationApp()
41+
image_dir = "./sampleDICOMs/1/BI_BREAST_SCREENING_BILATERAL_WITH_TOMOSYNTHESIS-2019-07-08/1/L_CC_C-View"
42+
43+
model_path = "./model/traced_ts_model.pt"
44+
app.run(input=image_dir, output="./output", model=model_path)
45+
46+
47+
if __name__ == "__main__":
48+
main()
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import json
2+
from typing import Dict, Text
3+
4+
import torch
5+
6+
import monai.deploy.core as md
7+
from monai.data import DataLoader, Dataset
8+
from monai.deploy.core import ExecutionContext, Image, InputContext, IOType, Operator, OutputContext
9+
from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader
10+
from monai.transforms import (
11+
Activations,
12+
Compose,
13+
EnsureChannelFirst,
14+
EnsureType,
15+
LoadImage,
16+
NormalizeIntensity,
17+
RepeatChannel,
18+
Resize,
19+
SqueezeDim,
20+
)
21+
22+
23+
@md.input("image", Image, IOType.IN_MEMORY)
24+
@md.output("result_text", Text, IOType.IN_MEMORY)
25+
class ClassifierOperator(Operator):
26+
def __init__(self):
27+
super().__init__()
28+
self._input_dataset_key = "image"
29+
self._pred_dataset_key = "pred"
30+
31+
def _convert_dicom_metadata_datatype(self, metadata: Dict):
32+
if not metadata:
33+
return metadata
34+
35+
# Try to convert data type for the well knowned attributes. Add more as needed.
36+
if metadata.get("SeriesInstanceUID", None):
37+
try:
38+
metadata["SeriesInstanceUID"] = str(metadata["SeriesInstanceUID"])
39+
except Exception:
40+
pass
41+
if metadata.get("row_pixel_spacing", None):
42+
try:
43+
metadata["row_pixel_spacing"] = float(metadata["row_pixel_spacing"])
44+
except Exception:
45+
pass
46+
if metadata.get("col_pixel_spacing", None):
47+
try:
48+
metadata["col_pixel_spacing"] = float(metadata["col_pixel_spacing"])
49+
except Exception:
50+
pass
51+
52+
print("Converted Image object metadata:")
53+
for k, v in metadata.items():
54+
print(f"{k}: {v}, type {type(v)}")
55+
56+
return metadata
57+
58+
def compute(self, op_input: InputContext, op_output: OutputContext, context: ExecutionContext):
59+
input_image = op_input.get("image")
60+
_reader = InMemImageReader(input_image)
61+
input_img_metadata = self._convert_dicom_metadata_datatype(input_image.metadata())
62+
img_name = str(input_img_metadata.get("SeriesInstanceUID", "Img_in_context"))
63+
64+
output_path = context.output.get().path
65+
66+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67+
model = context.models.get()
68+
69+
pre_transforms = self.pre_process(_reader)
70+
post_transforms = self.post_process()
71+
72+
dataset = Dataset(data=[{self._input_dataset_key: img_name}], transform=pre_transforms)
73+
dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)
74+
75+
with torch.no_grad():
76+
for d in dataloader:
77+
image = d[0].to(device)
78+
outputs = model(image)
79+
out = post_transforms(outputs).data.cpu().numpy()[0]
80+
print(out)
81+
82+
result_dict = (
83+
"A " + ":" + str(out[0]) + " B " + ":" + str(out[1]) + " C " + ":" + str(out[2]) + " D " + ":" + str(out[3])
84+
)
85+
result_dict_out = {"A": str(out[0]), "B": str(out[1]), "C": str(out[2]), "D": str(out[3])}
86+
output_folder = context.output.get().path
87+
output_folder.mkdir(parents=True, exist_ok=True)
88+
89+
output_path = output_folder / "output.json"
90+
with open(output_path, "w") as fp:
91+
json.dump(result_dict, fp)
92+
93+
op_output.set(result_dict, "result_text")
94+
95+
def pre_process(self, image_reader) -> Compose:
96+
return Compose(
97+
[
98+
LoadImage(reader=image_reader, image_only=True),
99+
EnsureChannelFirst(),
100+
SqueezeDim(dim=3),
101+
NormalizeIntensity(),
102+
Resize(spatial_size=(299, 299)),
103+
RepeatChannel(repeats=3),
104+
EnsureChannelFirst(),
105+
]
106+
)
107+
108+
def post_process(self) -> Compose:
109+
return Compose([EnsureType(), Activations(sigmoid=True)])

0 commit comments

Comments
 (0)