Skip to content

Commit 7a9d0cb

Browse files
committed
[DLMED] add scripts
Signed-off-by: Nic Ma <[email protected]>
1 parent 3d49380 commit 7a9d0cb

File tree

6 files changed

+104
-4
lines changed

6 files changed

+104
-4
lines changed

modules/bundles/hybrid_programming/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ This example mainly shows 2 typical use cases of hybrid programming with MONAI b
66
Please note that this example depends on the `spleen_segmentation` bundle example and show the hybrid programming via overriding the config file of it.
77

88
## commands example
9-
Override the `train` config with customized `transforms` and execute training:
9+
Export the customized python code to `PYTHONPATH`:
10+
```
11+
export PYTHONPATH=$PYTHONPATH:"<path to 'hybrid_programming/scripts'>"
12+
```
1013

14+
Override the `train` config with customized `transforms` and execute training:
1115
```
12-
python -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.json','configs/custom_train.json']" --logging_file configs/logging.conf
16+
python -m monai.bundle run training --meta_file <spleen_configs_path>/metadata.json --config_file "['<spleen_configs_path>/train.json','configs/custom_train.json']" --logging_file <spleen_configs_path>/logging.conf
1317
```
1418

1519
Parse the config in the python program and execute inference from the program:
1620

1721
```
18-
python -m scripts.inference.run --config_file configs/inference.json
22+
python -m scripts.inference run --config_file <spleen_configs_path>/inference.json --ckpt_path <path_to_checkpoint>
1923
```
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"train#preprocessing#transforms#6":
3+
{
4+
"_target_": "scripts.custom_transforms.PrintEnsureTyped",
5+
"keys": ["image", "label"]
6+
}
7+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from monai.config import KeysCollection
13+
from monai.transforms import EnsureTyped
14+
15+
16+
class PrintEnsureTyped(EnsureTyped):
17+
"""
18+
Extend the `EnsureTyped` transform to print the image shape.
19+
20+
Args:
21+
keys: keys of the corresponding items to be transformed.
22+
23+
"""
24+
25+
def __init__(self, keys: KeysCollection, data_type: str = "tensor") -> None:
26+
super().__init__(keys, data_type=data_type)
27+
28+
def __call__(self, data):
29+
d = dict(super().__call__(data=data))
30+
for key in self.key_iterator(d):
31+
print(f"data shape of {key}: {d[key].shape}")
32+
return d
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import torch
13+
from monai.bundle import ConfigParser
14+
from monai.data import decollate_batch
15+
from monai.utils.enums import CommonKeys
16+
17+
18+
def run(config_file: str, ckpt_path: str):
19+
parser = ConfigParser()
20+
parser.read_config(config_file)
21+
# edit the config content at runtime and lazy instantiation
22+
parser["inferer"]["roi_size"] = [160, 160, 160]
23+
24+
device = parser.get_parsed_content("device")
25+
# instantialize the components
26+
model = parser.get_parsed_content("network")
27+
model.load_state_dict(torch.load(ckpt_path))
28+
29+
dataloader = parser.get_parsed_content("dataloader")
30+
inferer = parser.get_parsed_content("inferer")
31+
postprocessing = parser.get_parsed_content("postprocessing")
32+
33+
model.eval()
34+
with torch.no_grad():
35+
for d in dataloader:
36+
images = d[CommonKeys.IMAGE].to(device)
37+
# define sliding window size and batch size for windows inference
38+
d[CommonKeys.PRED] = inferer(inputs=images, network=model)
39+
# decollate the batch data into a list of dictionaries, then execute postprocessing transforms
40+
[postprocessing(i) for i in decollate_batch(d)]
41+
42+
43+
if __name__ == "__main__":
44+
from monai.utils import optional_import
45+
46+
fire, _ = optional_import("fire")
47+
fire.Fire()

modules/bundles/spleen_segmentation/configs/metadata.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"0.1.0": "complete the model package",
66
"0.0.1": "initialize the model package structure"
77
},
8-
"monai_version": "0.8.0",
8+
"monai_version": "0.9.0",
99
"pytorch_version": "1.10.0",
1010
"numpy_version": "1.21.2",
1111
"optional_packages_version": {

0 commit comments

Comments
 (0)