Skip to content

Commit 72e3877

Browse files
dongyang0122dongy
andauthored
update tutorial (Project-MONAI#667)
Signed-off-by: dongy <[email protected]> Co-authored-by: dongy <[email protected]>
1 parent 69c937d commit 72e3877

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

automl/DiNTS/search_dints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,14 @@ def main():
332332
# train_ds_w = monai.data.Dataset(data=train_files_w, transform=train_transforms)
333333
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
334334

335-
train_loader_a = ThreadDataLoader(train_ds_a, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
336-
train_loader_w = ThreadDataLoader(train_ds_w, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
337-
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
335+
train_loader_a = ThreadDataLoader(train_ds_a, num_workers=4, batch_size=num_images_per_batch, shuffle=True)
336+
train_loader_w = ThreadDataLoader(train_ds_w, num_workers=4, batch_size=num_images_per_batch, shuffle=True)
337+
val_loader = ThreadDataLoader(val_ds, num_workers=4, batch_size=1, shuffle=False)
338338

339339
# DataLoader can be used as alternatives when ThreadDataLoader is less efficient.
340-
# train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
341-
# train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())
342-
# val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=torch.cuda.is_available())
340+
# train_loader_a = DataLoader(train_ds_a, batch_size=num_images_per_batch, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
341+
# train_loader_w = DataLoader(train_ds_w, batch_size=num_images_per_batch, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
342+
# val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=torch.cuda.is_available())
343343

344344
dints_space = monai.networks.nets.TopologySearch(
345345
channel_mul=0.5,

automl/DiNTS/train_dints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def main():
318318
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
319319
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
320320

321-
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=num_images_per_batch, shuffle=True)
322-
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
321+
train_loader = ThreadDataLoader(train_ds, num_workers=8, batch_size=num_images_per_batch, shuffle=True)
322+
val_loader = ThreadDataLoader(val_ds, num_workers=4, batch_size=1, shuffle=False)
323323

324324
ckpt = torch.load(args.arch_ckpt)
325325
node_a = ckpt['node_a']

0 commit comments

Comments
 (0)