Skip to content

Commit 405291f

Browse files
committed
fix #1850
Signed-off-by: YunLiu <[email protected]>
1 parent 1ad4540 commit 405291f

File tree

5 files changed: 96 additions and 40 deletions.

generation/maisi/maisi_diff_unet_training_tutorial.ipynb

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,47 @@
5757
},
5858
{
5959
"cell_type": "code",
60-
"execution_count": null,
60+
"execution_count": 2,
6161
"id": "e3bf0346",
6262
"metadata": {},
63-
"outputs": [],
63+
"outputs": [
64+
{
65+
"name": "stdout",
66+
"output_type": "stream",
67+
"text": [
68+
"MONAI version: 1.4.0rc10\n",
69+
"Numpy version: 1.24.4\n",
70+
"Pytorch version: 2.5.0a0+872d972e41.nv24.08.01\n",
71+
"MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
72+
"MONAI rev id: cac21f6936a2e8d6e4e57e4e958f8e32aae1585e\n",
73+
"MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py\n",
74+
"\n",
75+
"Optional dependencies:\n",
76+
"Pytorch Ignite version: 0.4.11\n",
77+
"ITK version: 5.4.0\n",
78+
"Nibabel version: 5.2.1\n",
79+
"scikit-image version: 0.23.2\n",
80+
"scipy version: 1.13.1\n",
81+
"Pillow version: 10.4.0\n",
82+
"Tensorboard version: 2.17.0\n",
83+
"gdown version: 5.2.0\n",
84+
"TorchVision version: 0.20.0a0\n",
85+
"tqdm version: 4.66.4\n",
86+
"lmdb version: 1.5.1\n",
87+
"psutil version: 5.9.8\n",
88+
"pandas version: 2.2.2\n",
89+
"einops version: 0.7.0\n",
90+
"transformers version: 4.40.2\n",
91+
"mlflow version: 2.16.0\n",
92+
"pynrrd version: 1.0.0\n",
93+
"clearml version: 1.16.3\n",
94+
"\n",
95+
"For details about installing the optional dependencies, please visit:\n",
96+
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
97+
"\n"
98+
]
99+
}
100+
],
64101
"source": [
65102
"from scripts.diff_model_setting import setup_logging\n",
66103
"import copy\n",
@@ -336,6 +373,8 @@
336373
" model_config_filepath,\n",
337374
" \"--model_def\",\n",
338375
" model_def_filepath,\n",
376+
" \"--num_gpus\",\n",
377+
" str(num_gpus),\n",
339378
"]\n",
340379
"\n",
341380
"run_torchrun(module, module_args, num_gpus=num_gpus)"
@@ -457,17 +496,17 @@
457496
"INFO:training:[config] num_train_timesteps -> 1000.\n",
458497
"INFO:training:num_files_train: 2\n",
459498
"INFO:training:Training from scratch.\n",
460-
"INFO:training:Scaling factor set to 0.89132159948349.\n",
461-
"INFO:training:scale_factor -> 0.89132159948349.\n",
499+
"INFO:training:Scaling factor set to 0.8903454542160034.\n",
500+
"INFO:training:scale_factor -> 0.8903454542160034.\n",
462501
"INFO:training:torch.set_float32_matmul_precision -> highest.\n",
463502
"INFO:training:Epoch 1, lr 0.0001.\n",
464-
"INFO:training:[2024-09-24 03:46:57] epoch 1, iter 1/2, loss: 0.7984, lr: 0.000100000000.\n",
465-
"INFO:training:[2024-09-24 03:46:58] epoch 1, iter 2/2, loss: 0.7911, lr: 0.000056250000.\n",
466-
"INFO:training:epoch 1 average loss: 0.7947.\n",
503+
"INFO:training:[2024-09-30 06:30:33] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n",
504+
"INFO:training:[2024-09-30 06:30:33] epoch 1, iter 2/2, loss: 0.7939, lr: 0.000056250000.\n",
505+
"INFO:training:epoch 1 average loss: 0.7957.\n",
467506
"INFO:training:Epoch 2, lr 2.5e-05.\n",
468-
"INFO:training:[2024-09-24 03:46:59] epoch 2, iter 1/2, loss: 0.7910, lr: 0.000025000000.\n",
469-
"INFO:training:[2024-09-24 03:46:59] epoch 2, iter 2/2, loss: 0.7897, lr: 0.000006250000.\n",
470-
"INFO:training:epoch 2 average loss: 0.7903.\n",
507+
"INFO:training:[2024-09-30 06:30:35] epoch 2, iter 1/2, loss: 0.7902, lr: 0.000025000000.\n",
508+
"INFO:training:[2024-09-30 06:30:35] epoch 2, iter 2/2, loss: 0.7889, lr: 0.000006250000.\n",
509+
"INFO:training:epoch 2 average loss: 0.7895.\n",
471510
"\n"
472511
]
473512
}
@@ -484,6 +523,8 @@
484523
" model_config_filepath,\n",
485524
" \"--model_def\",\n",
486525
" model_def_filepath,\n",
526+
" \"--num_gpus\",\n",
527+
" str(num_gpus),\n",
487528
"]\n",
488529
"\n",
489530
"run_torchrun(module, module_args, num_gpus=num_gpus)"
@@ -518,24 +559,24 @@
518559
"output_type": "stream",
519560
"text": [
520561
"\n",
521-
"INFO:inference:Using cuda:0 of 1 with random seed: 62801\n",
562+
"INFO:inference:Using cuda:0 of 1 with random seed: 93612\n",
522563
"INFO:inference:[config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
523-
"INFO:inference:[config] random_seed -> 62801.\n",
564+
"INFO:inference:[config] random_seed -> 93612.\n",
524565
"INFO:inference:[config] output_prefix -> unet_3d.\n",
525566
"INFO:inference:[config] output_size -> (256, 256, 128).\n",
526567
"INFO:inference:[config] out_spacing -> (1.0, 1.0, 0.75).\n",
527568
"INFO:root:`controllable_anatomy_size` is not provided.\n",
528569
"INFO:inference:checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
529-
"INFO:inference:scale_factor -> 0.89132159948349.\n",
570+
"INFO:inference:scale_factor -> 0.8903454542160034.\n",
530571
"INFO:inference:num_downsample_level -> 4, divisor -> 4.\n",
531572
"INFO:inference:noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
532573
"\n",
533574
" 0%| | 0/10 [00:00<?, ?it/s]\n",
534-
" 10%|███████▍ | 1/10 [00:00<00:02, 3.62it/s]\n",
535-
" 40%|█████████████████████████████▌ | 4/10 [00:00<00:00, 12.53it/s]\n",
536-
" 80%|███████████████████████████████████████████████████████████▏ | 8/10 [00:00<00:00, 19.54it/s]\n",
537-
"100%|█████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 18.16it/s]\n",
538-
"INFO:inference:Saved ./temp_work_dir/./predictions/unet_3d_seed62801_size256x256x128_spacing1.00x1.00x0.75_20240924034721.nii.gz.\n",
575+
" 10%|███████▍ | 1/10 [00:00<00:02, 3.48it/s]\n",
576+
" 40%|█████████████████████████████▌ | 4/10 [00:00<00:00, 12.23it/s]\n",
577+
" 80%|███████████████████████████████████████████████████████████▏ | 8/10 [00:00<00:00, 19.26it/s]\n",
578+
"100%|█████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 17.80it/s]\n",
579+
"INFO:inference:Saved ./temp_work_dir/./predictions/unet_3d_seed93612_size256x256x128_spacing1.00x1.00x0.75_20240930063144_rank0.nii.gz.\n",
539580
"\n"
540581
]
541582
}
@@ -552,6 +593,8 @@
552593
" model_config_filepath,\n",
553594
" \"--model_def\",\n",
554595
" model_def_filepath,\n",
596+
" \"--num_gpus\",\n",
597+
" str(num_gpus),\n",
555598
"]\n",
556599
"\n",
557600
"run_torchrun(module, module_args, num_gpus=num_gpus)\n",
@@ -562,7 +605,7 @@
562605
],
563606
"metadata": {
564607
"kernelspec": {
565-
"display_name": "Python 3 (ipykernel)",
608+
"display_name": "Python 3",
566609
"language": "python",
567610
"name": "python3"
568611
},

generation/maisi/scripts/diff_model_create_training_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def process_file(
160160

161161

162162
@torch.inference_mode()
163-
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
163+
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
164164
"""
165165
Create training data for the diffusion model.
166166
@@ -170,7 +170,7 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
170170
model_def_path (str): Path to the model definition file.
171171
"""
172172
args = load_config(env_config_path, model_config_path, model_def_path)
173-
local_rank, world_size, device = initialize_distributed()
173+
local_rank, world_size, device = initialize_distributed(num_gpus=num_gpus)
174174
logger = setup_logging("creating training data")
175175
logger.info(f"Using device {device}")
176176

@@ -224,6 +224,9 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
224224
parser.add_argument(
225225
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
226226
)
227+
parser.add_argument(
228+
"--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training"
229+
)
227230

228231
args = parser.parse_args()
229-
diff_model_create_training_data(args.env_config, args.model_config, args.model_def)
232+
diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)

generation/maisi/scripts/diff_model_infer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def save_image(
211211

212212

213213
@torch.inference_mode()
214-
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
214+
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
215215
"""
216216
Main function to run the diffusion model inference.
217217
@@ -221,7 +221,7 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
221221
model_def_path (str): Path to the model definition file.
222222
"""
223223
args = load_config(env_config_path, model_config_path, model_def_path)
224-
local_rank, world_size, device = initialize_distributed()
224+
local_rank, world_size, device = initialize_distributed(num_gpus)
225225
logger = setup_logging("inference")
226226
random_seed = set_random_seed(
227227
args.diffusion_unet_inference["random_seed"] + local_rank
@@ -311,6 +311,12 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
311311
default="./configs/config_maisi.json",
312312
help="Path to model definition file",
313313
)
314+
parser.add_argument(
315+
"--num_gpus",
316+
type=int,
317+
default=1,
318+
help="Number of GPUs to use for distributed inference",
319+
)
314320

315321
args = parser.parse_args()
316-
diff_model_infer(args.env_config, args.model_config, args.model_def)
322+
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus)

generation/maisi/scripts/diff_model_setting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def load_config(env_config_path: str, model_config_path: str, model_def_path: st
7474
return args
7575

7676

77-
def initialize_distributed() -> tuple:
77+
def initialize_distributed(num_gpus) -> tuple:
7878
"""
7979
Initialize distributed training.
8080
8181
Returns:
8282
tuple: local_rank, world_size, and device.
8383
"""
84-
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
84+
if torch.cuda.is_available() and num_gpus > 1:
8585
dist.init_process_group(backend="nccl", init_method="env://")
8686
local_rank = dist.get_rank()
8787
world_size = dist.get_world_size()

generation/maisi/scripts/diff_model_train.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Lo
108108
unet = define_instance(args, "diffusion_unet_def").to(device)
109109
unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet)
110110

111-
if torch.cuda.device_count() > 1:
111+
if dist.is_initialized():
112112
unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True)
113113

114114
if args.existing_ckpt_filepath is None:
115115
logger.info("Training from scratch.")
116116
else:
117117
checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device)
118-
if torch.cuda.device_count() > 1:
118+
if dist.is_initialized():
119119
unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
120120
else:
121121
unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
@@ -143,8 +143,9 @@ def calculate_scale_factor(
143143
scale_factor = 1 / torch.std(z)
144144
logger.info(f"Scaling factor set to {scale_factor}.")
145145

146-
dist.barrier()
147-
dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
146+
if dist.is_initialized():
147+
dist.barrier()
148+
dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
148149
logger.info(f"scale_factor -> {scale_factor}.")
149150
return scale_factor
150151

@@ -271,7 +272,7 @@ def train_one_epoch(
271272
)
272273
)
273274

274-
if torch.cuda.device_count() > 1:
275+
if dist.is_initialized():
275276
dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
276277

277278
return loss_torch
@@ -298,7 +299,7 @@ def save_checkpoint(
298299
ckpt_folder (str): Checkpoint folder path.
299300
args (argparse.Namespace): Configuration arguments.
300301
"""
301-
unet_state_dict = unet.module.state_dict() if torch.cuda.device_count() > 1 else unet.state_dict()
302+
unet_state_dict = unet.module.state_dict() if dist.is_initialized() else unet.state_dict()
302303
torch.save(
303304
{
304305
"epoch": epoch + 1,
@@ -311,7 +312,7 @@ def save_checkpoint(
311312
)
312313

313314

314-
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
315+
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
315316
"""
316317
Main function to train a diffusion model.
317318
@@ -321,7 +322,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
321322
model_def_path (str): Path to the model definition file.
322323
"""
323324
args = load_config(env_config_path, model_config_path, model_def_path)
324-
local_rank, world_size, device = initialize_distributed()
325+
local_rank, world_size, device = initialize_distributed(num_gpus)
325326
logger = setup_logging("training")
326327

327328
logger.info(f"Using {device} of {world_size}")
@@ -350,10 +351,10 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
350351
train_files.append(
351352
{"image": str_img, "top_region_index": str_info, "bottom_region_index": str_info, "spacing": str_info}
352353
)
353-
354-
train_files = partition_dataset(
355-
data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
356-
)[local_rank]
354+
if dist.is_initialized():
355+
train_files = partition_dataset(
356+
data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
357+
)[local_rank]
357358

358359
train_loader = prepare_data(
359360
train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
@@ -429,6 +430,9 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
429430
parser.add_argument(
430431
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
431432
)
433+
parser.add_argument(
434+
"--num_gpus", type=int, default=1, help="Number of GPUs to use for training"
435+
)
432436

433437
args = parser.parse_args()
434-
diff_model_train(args.env_config, args.model_config, args.model_def)
438+
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)

0 commit comments

Comments
 (0)