Skip to content

Commit e573874

Browse files
670 add bundle example for multi-gpu training (Project-MONAI#673)
* [DLMED] draft config Signed-off-by: Nic Ma <[email protected]> * [DLMED] update for test Signed-off-by: Nic Ma <[email protected]> * [DLMED] update based on enhancement Signed-off-by: Nic Ma <[email protected]> * [DLMED] update tutorial Signed-off-by: Nic Ma <[email protected]> * [DLMED] simplify to override Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> * [DLMED] remove test file Signed-off-by: Nic Ma <[email protected]> * [DLMED] add evaluation config Signed-off-by: Nic Ma <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [DLMED] simplify inference Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 404a92b commit e573874

File tree

6 files changed

+162
-42
lines changed

6 files changed

+162
-42
lines changed

modules/bundles/get_started.ipynb

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
"source": [
77
"# Get started to MONAI bundle\n",
88
"\n",
9-
"MONAI bundle usually includes the stored weights of a model, TorchScript model, JSON files that include configs and metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include.\n",
9+
"A MONAI bundle usually includes the stored weights of a model, TorchScript model, JSON files which include configs and metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include.\n",
1010
"\n",
11-
"For more information about MONAI bundle description: https://docs.monai.io/en/latest/bundle_intro.html.\n",
11+
"For more information about MONAI bundles read the description: https://docs.monai.io/en/latest/bundle_intro.html.\n",
1212
"\n",
13-
"This notebook is step-by-step tutorial to help get started to develop a bundle package, which contains a config file to construct the training pipeline and also have a `metadata.json` file to define the metadata information.\n",
13+
"This notebook is a step-by-step tutorial to help get started to develop a bundle package, which contains a config file to construct the training pipeline and also has a `metadata.json` file to define the metadata information.\n",
1414
"\n",
15-
"This notebook mainly contains below sections:\n",
15+
"This notebook mainly contains the below sections:\n",
1616
"- Define a training config with `JSON` or `YAML` format\n",
1717
"- Execute training based on bundle scripts and configs\n",
1818
"- Hybrid programming with config and python code\n",
@@ -21,7 +21,6 @@
2121
"- Instantiate a python object from a dictionary config with `_target_` indicating class or function name or module path.\n",
2222
"- Execute python expression from a string config with the `$` syntax.\n",
2323
"- Refer to other python object with the `@` syntax.\n",
24-
"- Require other independent config items to execute or instantiate first with the `_requires_` syntax.\n",
2524
"- Macro text replacement with the `%` syntax to simplify the config content.\n",
2625
"- Leverage the `_disabled_` syntax to tune or debug different components.\n",
2726
"- Override config content at runtime.\n",
@@ -144,13 +143,13 @@
144143
"source": [
145144
"## Define train config - Set imports and input / output environments\n",
146145
"\n",
147-
"Now let's start to define the config file for a regular training task. MONAI bundle support both `JSON` and `YAML` format, here we use `JSON` as example.\n",
146+
"Now let's start to define the config file for a regular training task. MONAI bundles support both `JSON` and `YAML` format, here we use `JSON` as the example.\n",
148147
"\n",
149148
"According to the predefined syntax of MONAI bundle, `$` indicates an expression to evaluate, `@` refers to another object in the config content. For more details about the syntax in bundle config, please check: https://docs.monai.io/en/latest/config_syntax.html.\n",
150149
"\n",
151-
"Please note that MONAI bundle doesn't require any hard-code logic in the config, so users can define the config content in any structure.\n",
150+
"Please note that a MONAI bundle doesn't require any hard-coded logic in the config, so users can define the config content in any structure.\n",
152151
"\n",
153-
"For the first step, import `os` and `glob` to use in the expressions (start with `$`). Then define input / output environments and enable `cudnn.benchmark` for better performance."
152+
"For the first step, import `os` and `glob` to use in the expressions (start with `$`), then define input / output environments and enable `cudnn.benchmark` for better performance."
154153
]
155154
},
156155
{
@@ -164,8 +163,6 @@
164163
" \"$import os\",\n",
165164
" \"$import ignite\"\n",
166165
" ],\n",
167-
" \"determinism\": \"$monai.utils.set_determinism(seed=123)\",\n",
168-
" \"cudnn_opt\": \"$setattr(torch.backends.cudnn, 'benchmark', True)\",\n",
169166
" \"device\": \"$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\",\n",
170167
" \"ckpt_path\": \"/workspace/data/models/model.pt\",\n",
171168
" \"dataset_dir\": \"/workspace/data/Task09_Spleen\",\n",
@@ -325,7 +322,9 @@
325322
"cell_type": "markdown",
326323
"metadata": {},
327324
"source": [
328-
"The train and validation image file names are organized into a list of dictionaries."
325+
"The train and validation image file names are organized into a list of dictionaries.\n",
326+
"\n",
327+
"Here we use `dataset` instance as 1 argument of `dataloader` by the `@` syntax, and please note that `\"#\"` in the reference id are interpreted as special characters to go one level further into the nested config structures. For example: `\"dataset\": \"@train#dataset\"`."
329328
]
330329
},
331330
{
@@ -430,8 +429,6 @@
430429
"\n",
431430
"Here we use MONAI engine `SupervisedTrainer` to execute a regular training.\n",
432431
"\n",
433-
"`determinism` and `cudnn_opt` are not args of the trainer, but should execute them before training, so here mark them in the `_requires_` field.\n",
434-
"\n",
435432
"If users have customized logic, then can put the logic in the `iteration_update` arg or implement their own `trainer` in python code and set `_target_` to the class directly."
436433
]
437434
},
@@ -442,7 +439,6 @@
442439
"```json\n",
443440
"\"trainer\": {\n",
444441
" \"_target_\": \"SupervisedTrainer\",\n",
445-
" \"_requires_\": [\"@determinism\", \"@cudnn_opt\"],\n",
446442
" \"max_epochs\": 100,\n",
447443
" \"device\": \"@device\",\n",
448444
" \"train_data_loader\": \"@train#dataloader\",\n",
@@ -499,7 +495,7 @@
499495
"source": [
500496
"## Define metadata information\n",
501497
"\n",
502-
"Optinally, we can define a `metadata` file in the bundle, which contains the metadata information relating to the model, including what the shape and format of inputs and outputs are, what the meaning of the outputs are, what type of model is present, and other information. The structure is a dictionary containing a defined set of keys with additional user-specified keys.\n",
498+
"We can define a `metadata` file in the bundle, which contains the metadata information relating to the model, including what the shape and format of inputs and outputs are, what the meaning of the outputs are, what type of model is present, and other information. The structure is a dictionary containing a defined set of keys with additional user-specified keys.\n",
503499
"\n",
504500
"A typical `metadata` example is available: \n",
505501
"https://github.com/Project-MONAI/tutorials/blob/master/modules/bundles/spleen_segmentation/configs/metadata.json"
@@ -513,14 +509,29 @@
513509
"\n",
514510
"There are several predefined scripts in MONAI bundle module to help execute `regular training`, `metadata verification base on schema`, `network input / output verification`, `export to TorchScript model`, etc.\n",
515511
"\n",
516-
"Here we leverage the `run` script and specify the ID of trainer in the config."
512+
"Here we leverage the `run` script and specify the ID of trainer in the config.\n",
513+
"\n",
514+
"Just define the entry point expressions in the config to execute in order, and specify the `runner_id` in CLI script."
517515
]
518516
},
519517
{
520518
"cell_type": "markdown",
521519
"metadata": {},
522520
"source": [
523-
"`python -m monai.bundle run \"'train#trainer'\" --config_file configs/train.json`"
521+
"```json\n",
522+
"\"training\": [\n",
523+
" \"$monai.utils.set_determinism(seed=123)\",\n",
524+
" \"$setattr(torch.backends.cudnn, 'benchmark', True)\",\n",
525+
" \"$@train#trainer.run()\"\n",
526+
"]\n",
527+
"```"
528+
]
529+
},
530+
{
531+
"cell_type": "markdown",
532+
"metadata": {},
533+
"source": [
534+
"`python -m monai.bundle run training --config_file configs/train.json`"
524535
]
525536
},
526537
{
@@ -538,7 +549,7 @@
538549
"cell_type": "markdown",
539550
"metadata": {},
540551
"source": [
541-
"`python -m monai.bundle run \"'train#trainer'\" --config_file configs/train.json --device \"\\$torch.device('cuda:1')\"`"
552+
"`python -m monai.bundle run training --config_file configs/train.json --device \"\\$torch.device('cuda:1')\"`"
542553
]
543554
},
544555
{
@@ -552,7 +563,7 @@
552563
"cell_type": "markdown",
553564
"metadata": {},
554565
"source": [
555-
"`python -m monai.bundle run \"'train#trainer'\" --config_file configs/train.json --network \"%configs/test.json#network\"`"
566+
"`python -m monai.bundle run training --config_file configs/train.json --network \"%configs/test.json#network\"`"
556567
]
557568
},
558569
{
@@ -561,8 +572,9 @@
561572
"source": [
562573
"## Hybrid programming with config and python code\n",
563574
"\n",
564-
"MONAI bundle is flexible to support customized logic, there are several ways to achieve that:\n",
565-
"- If defining own components like transform, loss, trainer, etc. in a python file, just use its module path in `_target_`.\n",
575+
"A MONAI bundle supports flexible customized logic, there are several ways to achieve this:\n",
576+
"\n",
577+
"- If defining own components like transform, loss, trainer, etc. in a python file, just use its module path in `_target_` within the config file.\n",
566578
"- Parse the config in your own python program and do lazy instantiation with customized logic.\n",
567579
"\n",
568580
"Here we show an example to parse the config in python code and execute the training."
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
{
2+
"validate#postprocessing":{
3+
"_target_": "Compose",
4+
"transforms": [
5+
{
6+
"_target_": "Activationsd",
7+
"keys": "pred",
8+
"softmax": true
9+
},
10+
{
11+
"_target_": "Invertd",
12+
"keys": ["pred", "label"],
13+
"transform": "@validate#preprocessing",
14+
"orig_keys": "image",
15+
"meta_key_postfix": "meta_dict",
16+
"nearest_interp": [false, true],
17+
"to_tensor": true
18+
},
19+
{
20+
"_target_": "AsDiscreted",
21+
"keys": ["pred", "label"],
22+
"argmax": [true, false],
23+
"to_onehot": 2
24+
},
25+
{
26+
"_target_": "SaveImaged",
27+
"keys": "pred",
28+
"meta_keys": "pred_meta_dict",
29+
"output_dir": "@output_dir",
30+
"resample": false,
31+
"squeeze_end_dims": true
32+
}
33+
]
34+
},
35+
"validate#handlers": [
36+
{
37+
"_target_": "CheckpointLoader",
38+
"load_path": "$@ckpt_dir + '/model.pt'",
39+
"load_dict": {"model": "@network"}
40+
},
41+
{
42+
"_target_": "StatsHandler",
43+
"iteration_log": false
44+
},
45+
{
46+
"_target_": "MetricsSaver",
47+
"save_dir": "@output_dir",
48+
"metrics": ["val_mean_dice", "val_acc"],
49+
"metric_details": ["val_mean_dice"],
50+
"batch_transform": "$monai.handlers.from_engine(['image_meta_dict'])",
51+
"summary_ops": "*"
52+
}
53+
],
54+
"evaluating": [
55+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
56+
"$@validate#evaluator.run()"
57+
]
58+
}

modules/bundles/spleen_segmentation/configs/inference.json

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
"$import glob",
44
"$import os"
55
],
6-
"cudnn_opt": "$setattr(torch.backends.cudnn, 'benchmark', True)",
7-
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
8-
"ckpt_path": "/workspace/data/tutorials/modules/bundles/spleen_segmentation/models/model.pt",
9-
"download_ckpt": "$monai.apps.utils.download_url('https://huggingface.co/MONAI/example_spleen_segmentation/resolve/main/model.pt', @ckpt_path)",
6+
"bundle_root": "/workspace/data/tutorials/modules/bundles/spleen_segmentation",
7+
"output_dir": "$@bundle_root + '/eval'",
108
"dataset_dir": "/workspace/data/Task09_Spleen",
119
"datalist": "$list(sorted(glob.glob(@dataset_dir + '/imagesTs/*.nii.gz')))",
10+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
1211
"network_def": {
1312
"_target_": "UNet",
1413
"spatial_dims": 3,
@@ -101,16 +100,14 @@
101100
"_target_": "SaveImaged",
102101
"keys": "pred",
103102
"meta_keys": "pred_meta_dict",
104-
"output_dir": "eval"
103+
"output_dir": "@output_dir"
105104
}
106105
]
107106
},
108107
"handlers": [
109108
{
110109
"_target_": "CheckpointLoader",
111-
"_requires_": "@download_ckpt",
112-
"_disabled_": "$not os.path.exists(@ckpt_path)",
113-
"load_path": "@ckpt_path",
110+
"load_path": "$@bundle_root + '/models/model.pt'",
114111
"load_dict": {"model": "@network"}
115112
},
116113
{
@@ -120,13 +117,16 @@
120117
],
121118
"evaluator": {
122119
"_target_": "SupervisedEvaluator",
123-
"_requires_": "@cudnn_opt",
124120
"device": "@device",
125121
"val_data_loader": "@dataloader",
126122
"network": "@network",
127123
"inferer": "@inferer",
128124
"postprocessing": "@postprocessing",
129125
"val_handlers": "@handlers",
130126
"amp": true
131-
}
127+
},
128+
"evaluating": [
129+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
130+
131+
]
132132
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"device": "$torch.device(f'cuda:{dist.get_rank()}')",
3+
"network": {
4+
"_target_": "torch.nn.parallel.DistributedDataParallel",
5+
"module": "$@network_def.to(@device)",
6+
"device_ids": ["@device"]
7+
},
8+
"train#sampler": {
9+
"_target_": "DistributedSampler",
10+
"dataset": "@train#dataset",
11+
"even_divisible": true,
12+
"shuffle": true
13+
},
14+
"train#dataloader#sampler": "@train#sampler",
15+
"train#dataloader#shuffle": false,
16+
"train#trainer#train_handlers": "$@train#handlers[: 1 if dist.get_rank() > 0 else None]",
17+
"validate#sampler": {
18+
"_target_": "DistributedSampler",
19+
"dataset": "@validate#dataset",
20+
"even_divisible": false,
21+
"shuffle": false
22+
},
23+
"validate#dataloader#sampler": "@validate#sampler",
24+
"validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
25+
"training": [
26+
"$import torch.distributed as dist",
27+
"$dist.init_process_group(backend='nccl')",
28+
"$torch.cuda.set_device(@device)",
29+
"$monai.utils.set_determinism(seed=123)",
30+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
31+
"$@train#trainer.run()",
32+
"$dist.destroy_process_group()"
33+
]
34+
}

modules/bundles/spleen_segmentation/configs/train.json

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
"$import os",
55
"$import ignite"
66
],
7-
"determinism": "$monai.utils.set_determinism(seed=123)",
8-
"cudnn_opt": "$setattr(torch.backends.cudnn, 'benchmark', True)",
9-
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
10-
"ckpt_dir": "/workspace/data/tutorials/modules/bundles/spleen_segmentation/models",
7+
"bundle_root": "/workspace/data/tutorials/modules/bundles/spleen_segmentation",
8+
"ckpt_dir": "$@bundle_root + '/models'",
9+
"output_dir": "$@bundle_root + '/eval'",
1110
"dataset_dir": "/workspace/data/Task09_Spleen",
1211
"images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))",
1312
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))",
13+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
1414
"network_def": {
1515
"_target_": "UNet",
1616
"spatial_dims": 3,
@@ -94,7 +94,7 @@
9494
"_target_": "DataLoader",
9595
"dataset": "@train#dataset",
9696
"batch_size": 2,
97-
"shuffle": false,
97+
"shuffle": true,
9898
"num_workers": 4
9999
},
100100
"inferer": {
@@ -130,7 +130,7 @@
130130
},
131131
{
132132
"_target_": "TensorBoardStatsHandler",
133-
"log_dir": "eval",
133+
"log_dir": "@output_dir",
134134
"tag_name": "train_loss",
135135
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
136136
}
@@ -143,7 +143,6 @@
143143
},
144144
"trainer": {
145145
"_target_": "SupervisedTrainer",
146-
"_requires_": ["@determinism", "@cudnn_opt"],
147146
"max_epochs": 100,
148147
"device": "@device",
149148
"train_data_loader": "@train#dataloader",
@@ -196,7 +195,7 @@
196195
},
197196
{
198197
"_target_": "TensorBoardStatsHandler",
199-
"log_dir": "eval",
198+
"log_dir": "@output_dir",
200199
"iteration_log": false
201200
},
202201
{
@@ -232,5 +231,10 @@
232231
"val_handlers": "@validate#handlers",
233232
"amp": true
234233
}
235-
}
234+
},
235+
"training": [
236+
"$monai.utils.set_determinism(seed=123)",
237+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
238+
"$@train#trainer.run()"
239+
]
236240
}

modules/bundles/spleen_segmentation/docs/README.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,25 @@ Mean Dice = 0.96
2626
Execute training:
2727

2828
```
29-
python -m monai.bundle run "'train#trainer'" --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
29+
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
30+
```
31+
32+
Override the `train` config to execute multi-GPU training:
33+
34+
```
35+
torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training --meta_file configs/metadata.json --config_file "['configs/train.json','configs/multi_gpu_train.json']" --logging_file configs/logging.conf
36+
```
37+
38+
Override the `train` config to execute evaluation with the trained model:
39+
40+
```
41+
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file "['configs/train.json','configs/evaluate.json']" --logging_file configs/logging.conf
3042
```
3143

3244
Execute inference:
3345

3446
```
35-
python -m monai.bundle run evaluator --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
47+
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
3648
```
3749

3850
Verify the metadata format:

0 commit comments

Comments
 (0)