Skip to content

Commit e8f8b26

Browse files
committed
[DLMED] draft config
Signed-off-by: Nic Ma <[email protected]>
1 parent 6382e8a commit e8f8b26

File tree

2 files changed

+264
-4
lines changed

2 files changed

+264
-4
lines changed

modules/bundles/spleen_segmentation/configs/train.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
"determinism": "$monai.utils.set_determinism(seed=123)",
88
"cudnn_opt": "$setattr(torch.backends.cudnn, 'benchmark', True)",
99
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
10-
"ckpt_dir": "/workspace/data/tutorials/modules/bundles/spleen_segmentation/models",
11-
"dataset_dir": "/workspace/data/Task09_Spleen",
10+
"ckpt_dir": "/workspace/data/medical/tutorials/modules/bundles/spleen_segmentation/models",
11+
"dataset_dir": "/workspace/data/medical/Task09_Spleen",
1212
"images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))",
1313
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))",
1414
"network_def": {
@@ -94,7 +94,7 @@
9494
"_target_": "DataLoader",
9595
"dataset": "@train#dataset",
9696
"batch_size": 2,
97-
"shuffle": false,
97+
"shuffle": true,
9898
"num_workers": 4
9999
},
100100
"inferer": {
@@ -143,7 +143,7 @@
143143
},
144144
"trainer": {
145145
"_target_": "SupervisedTrainer",
146-
"_requires_": ["@determinism", "@cudnn_opt"],
146+
"_requires_": ["@ddp_init", "@determinism", "@cudnn_opt"],
147147
"max_epochs": 100,
148148
"device": "@device",
149149
"train_data_loader": "@train#dataloader",
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
{
2+
"imports": [
3+
"$import glob",
4+
"$import os",
5+
"$import ignite",
6+
"$import torch.distributed as dist"
7+
],
8+
"ddp_init": "$dist.init_process_group(backend='nccl', init_method='env://')",
9+
"determinism": "$monai.utils.set_determinism(seed=123)",
10+
"cudnn_opt": "$setattr(torch.backends.cudnn, 'benchmark', True)",
11+
"device": "$torch.device(f'cuda:{dist.get_rank()}')",
12+
"set_device": "$torch.cuda.set_device(@device)",
13+
"ckpt_dir": "/workspace/data/medical/tutorials/modules/bundles/spleen_segmentation/models",
14+
"dataset_dir": "/workspace/data/medical/Task09_Spleen",
15+
"images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))",
16+
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))",
17+
"network_def": {
18+
"_target_": "UNet",
19+
"spatial_dims": 3,
20+
"in_channels": 1,
21+
"out_channels": 2,
22+
"channels": [16, 32, 64, 128, 256],
23+
"strides": [2, 2, 2, 2],
24+
"num_res_units": 2,
25+
"norm": "batch"
26+
},
27+
"gpu_net": "$@network_def.to(@device)",
28+
"loss": {
29+
"_target_": "DiceCELoss",
30+
"to_onehot_y": true,
31+
"softmax": true,
32+
"squared_pred": true,
33+
"batch": true
34+
},
35+
"optimizer": {
36+
"_target_": "torch.optim.Adam",
37+
"params": "$@gpu_net.parameters()",
38+
"lr": 1e-4
39+
},
40+
"network": {
41+
"_target_": "torch.nn.parallel.DistributedDataParallel",
42+
"module": "@gpu_net",
43+
"device_ids": ["@device"]
44+
},
45+
"train": {
46+
"preprocessing": {
47+
"_target_": "Compose",
48+
"transforms": [
49+
{
50+
"_target_": "LoadImaged",
51+
"keys": ["image", "label"]
52+
},
53+
{
54+
"_target_": "EnsureChannelFirstd",
55+
"keys": ["image", "label"]
56+
},
57+
{
58+
"_target_": "Orientationd",
59+
"keys": ["image", "label"],
60+
"axcodes": "RAS"
61+
},
62+
{
63+
"_target_": "Spacingd",
64+
"keys": ["image", "label"],
65+
"pixdim": [1.5, 1.5, 2.0],
66+
"mode": ["bilinear", "nearest"]
67+
},
68+
{
69+
"_target_": "ScaleIntensityRanged",
70+
"keys": "image",
71+
"a_min": -57,
72+
"a_max": 164,
73+
"b_min": 0,
74+
"b_max": 1,
75+
"clip": true
76+
},
77+
{
78+
"_target_": "RandCropByPosNegLabeld",
79+
"keys": ["image", "label"],
80+
"label_key": "label",
81+
"spatial_size": [96, 96, 96],
82+
"pos": 1,
83+
"neg": 1,
84+
"num_samples": 4,
85+
"image_key": "image",
86+
"image_threshold": 0
87+
},
88+
{
89+
"_target_": "EnsureTyped",
90+
"keys": ["image", "label"]
91+
}
92+
]
93+
},
94+
"dataset": {
95+
"_target_": "CacheDataset",
96+
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-9], @labels[:-9])]",
97+
"transform": "@train#preprocessing",
98+
"cache_rate": 1.0,
99+
"num_workers": 4
100+
},
101+
"sampler": {
102+
"_target_": "DistributedSampler",
103+
"dataset": "@train#dataset",
104+
"even_divisible": true,
105+
"shuffle": true
106+
},
107+
"dataloader": {
108+
"_target_": "DataLoader",
109+
"dataset": "@train#dataset",
110+
"sampler": "@train#sampler",
111+
"batch_size": 2,
112+
"shuffle": false,
113+
"num_workers": 4
114+
},
115+
"inferer": {
116+
"_target_": "SimpleInferer"
117+
},
118+
"postprocessing": {
119+
"_target_": "Compose",
120+
"transforms": [
121+
{
122+
"_target_": "Activationsd",
123+
"keys": "pred",
124+
"softmax": true
125+
},
126+
{
127+
"_target_": "AsDiscreted",
128+
"keys": ["pred", "label"],
129+
"argmax": [true, false],
130+
"to_onehot": 2
131+
}
132+
]
133+
},
134+
"handlers": [
135+
{
136+
"_target_": "ValidationHandler",
137+
"validator": "@validate#evaluator",
138+
"epoch_level": true,
139+
"interval": 5
140+
},
141+
{
142+
"_target_": "StatsHandler",
143+
"_disabled_": "$dist.get_rank() > 0",
144+
"tag_name": "train_loss",
145+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
146+
},
147+
{
148+
"_target_": "TensorBoardStatsHandler",
149+
"_disabled_": "$dist.get_rank() > 0",
150+
"log_dir": "eval",
151+
"tag_name": "train_loss",
152+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
153+
}
154+
],
155+
"key_metric": {
156+
"train_accuracy": {
157+
"_target_": "ignite.metrics.Accuracy",
158+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
159+
}
160+
},
161+
"trainer": {
162+
"_target_": "SupervisedTrainer",
163+
"_requires_": ["@ddp_init", "@set_device", "@determinism", "@cudnn_opt"],
164+
"max_epochs": 100,
165+
"device": "@device",
166+
"train_data_loader": "@train#dataloader",
167+
"network": "@network",
168+
"loss_function": "@loss",
169+
"optimizer": "@optimizer",
170+
"inferer": "@train#inferer",
171+
"postprocessing": "@train#postprocessing",
172+
"key_train_metric": "@train#key_metric",
173+
"train_handlers": "@train#handlers",
174+
"amp": true
175+
}
176+
},
177+
"validate": {
178+
"preprocessing": {
179+
"_target_": "Compose",
180+
"transforms": [
181+
"%train#preprocessing#transforms#0",
182+
"%train#preprocessing#transforms#1",
183+
"%train#preprocessing#transforms#2",
184+
"%train#preprocessing#transforms#3",
185+
"%train#preprocessing#transforms#4",
186+
"%train#preprocessing#transforms#6"
187+
]
188+
},
189+
"dataset": {
190+
"_target_": "CacheDataset",
191+
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[-9:], @labels[-9:])]",
192+
"transform": "@validate#preprocessing",
193+
"cache_rate": 1.0
194+
},
195+
"sampler": {
196+
"_target_": "DistributedSampler",
197+
"dataset": "@validate#dataset",
198+
"even_divisible": false,
199+
"shuffle": false
200+
},
201+
"dataloader": {
202+
"_target_": "DataLoader",
203+
"dataset": "@validate#dataset",
204+
"sampler": "@validate#sampler",
205+
"batch_size": 1,
206+
"shuffle": false,
207+
"num_workers": 4
208+
},
209+
"inferer": {
210+
"_target_": "SlidingWindowInferer",
211+
"roi_size": [96, 96, 96],
212+
"sw_batch_size": 4,
213+
"overlap": 0.5
214+
},
215+
"postprocessing": "%train#postprocessing",
216+
"handlers": [
217+
{
218+
"_target_": "StatsHandler",
219+
"iteration_log": false
220+
},
221+
{
222+
"_target_": "TensorBoardStatsHandler",
223+
"log_dir": "eval",
224+
"iteration_log": false
225+
},
226+
{
227+
"_target_": "CheckpointSaver",
228+
"save_dir": "@ckpt_dir",
229+
"save_dict": {"model": "@network"},
230+
"save_key_metric": true,
231+
"key_metric_filename": "model.pt"
232+
}
233+
],
234+
"key_metric": {
235+
"val_mean_dice": {
236+
"_target_": "MeanDice",
237+
"include_background": false,
238+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
239+
}
240+
},
241+
"additional_metrics": {
242+
"val_accuracy": {
243+
"_target_": "ignite.metrics.Accuracy",
244+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
245+
}
246+
},
247+
"evaluator": {
248+
"_target_": "SupervisedEvaluator",
249+
"device": "@device",
250+
"val_data_loader": "@validate#dataloader",
251+
"network": "@network",
252+
"inferer": "@validate#inferer",
253+
"postprocessing": "@validate#postprocessing",
254+
"key_val_metric": "@validate#key_metric",
255+
"additional_metrics": "@validate#additional_metrics",
256+
"val_handlers": "$@validate#handlers if dist.get_rank() > 0 else None",
257+
"amp": true
258+
}
259+
}
260+
}

0 commit comments

Comments
 (0)