Skip to content

Commit a4c8563

Browse files
mingxin-zhengpre-commit-ci[bot]wyli
authored
fix pep 8 for all previously-excluded notebooks (#1137)
Signed-off-by: Mingxin Zheng <[email protected]> Fixes #1136 . ### Description The PR expands the scope of PEP 8 checks to cover some notebooks there were previously excluded ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Notebook runs automatically `./runner --no-run` Signed-off-by: Mingxin Zheng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <[email protected]>
1 parent 44b54d0 commit a4c8563

File tree

15 files changed

+291
-264
lines changed

15 files changed

+291
-264
lines changed

3d_segmentation/spleen_segmentation_3d_visualization_basic.ipynb

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -535,15 +535,15 @@
535535
"# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer\n",
536536
"device = torch.device(\"cuda:0\")\n",
537537
"\n",
538-
"UNet_meatdata = dict(\n",
539-
" spatial_dims=3,\n",
540-
" in_channels=1,\n",
541-
" out_channels=2,\n",
542-
" channels=(16, 32, 64, 128, 256),\n",
543-
" strides=(2, 2, 2, 2),\n",
544-
" num_res_units=2,\n",
545-
" norm=Norm.BATCH\n",
546-
")\n",
538+
"UNet_meatdata = {\n",
539+
" \"spatial_dims\": 3,\n",
540+
" \"in_channels\": 1,\n",
541+
" \"out_channels\": 2,\n",
542+
" \"channels\": (16, 32, 64, 128, 256),\n",
543+
" \"strides\": (2, 2, 2, 2),\n",
544+
" \"num_res_units\": 2,\n",
545+
" \"norm\": Norm.BATCH,\n",
546+
"}\n",
547547
"\n",
548548
"model = UNet(**UNet_meatdata).to(device)\n",
549549
"loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n",
@@ -554,7 +554,9 @@
554554
"Optimizer_metadata = {}\n",
555555
"for ind, param_group in enumerate(optimizer.param_groups):\n",
556556
" optim_meta_keys = list(param_group.keys())\n",
557-
" Optimizer_metadata[f'param_group_{ind}'] = {key: value for (key, value) in param_group.items() if 'params' not in key}"
557+
" Optimizer_metadata[f'param_group_{ind}'] = {\n",
558+
" key: value for (key, value) in param_group.items() if 'params' not in key\n",
559+
" }"
558560
]
559561
},
560562
{
@@ -612,26 +614,28 @@
612614
" loss.backward()\n",
613615
" optimizer.step()\n",
614616
" epoch_loss += loss.item()\n",
615-
" print(f\"{step}/{len(train_ds) // train_loader.batch_size}, \"\n",
616-
" f\"train_loss: {loss.item():.4f}\")\n",
617+
" print(\n",
618+
" f\"{step}/{len(train_ds) // train_loader.batch_size}, \"\n",
619+
" f\"train_loss: {loss.item():.4f}\"\n",
620+
" )\n",
617621
" # track batch loss metric\n",
618-
" aim_run.track(loss.item(), name=\"batch_loss\", context={'type':loss_type})\n",
619-
" \n",
622+
" aim_run.track(loss.item(), name=\"batch_loss\", context={'type': loss_type})\n",
623+
"\n",
620624
" epoch_loss /= step\n",
621625
" epoch_loss_values.append(epoch_loss)\n",
622-
" \n",
626+
"\n",
623627
" # track epoch loss metric\n",
624-
" aim_run.track(epoch_loss, name=\"epoch_loss\", context={'type':loss_type})\n",
628+
" aim_run.track(epoch_loss, name=\"epoch_loss\", context={'type': loss_type})\n",
625629
"\n",
626630
" print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n",
627631
"\n",
628632
" if (epoch + 1) % val_interval == 0:\n",
629633
" if (epoch + 1) % val_interval * 2 == 0:\n",
630634
" # track model params and gradients\n",
631-
" track_params_dists(model,aim_run)\n",
635+
" track_params_dists(model, aim_run)\n",
632636
" # THIS SEGMENT TAKES RELATIVELY LONG (Advise Against it)\n",
633637
" track_gradients_dists(model, aim_run)\n",
634-
" \n",
638+
"\n",
635639
" model.eval()\n",
636640
" with torch.no_grad():\n",
637641
" for index, val_data in enumerate(val_loader):\n",
@@ -643,28 +647,35 @@
643647
" sw_batch_size = 4\n",
644648
" val_outputs = sliding_window_inference(\n",
645649
" val_inputs, roi_size, sw_batch_size, model)\n",
646-
" \n",
650+
"\n",
647651
" # tracking input, label and output images with Aim\n",
648652
" output = torch.argmax(val_outputs, dim=1)[0, :, :, slice_to_track].float()\n",
649-
" \n",
650-
" aim_run.track(aim.Image(val_inputs[0, 0, :, :, slice_to_track], \\\n",
651-
" caption=f'Input Image: {index}'), \\\n",
652-
" name='validation', context={'type':'input'})\n",
653-
" aim_run.track(aim.Image(val_labels[0, 0, :, :, slice_to_track], \\\n",
654-
" caption=f'Label Image: {index}'), \\\n",
655-
" name='validation', context={'type':'label'})\n",
656-
" aim_run.track(aim.Image(output, caption=f'Predicted Label: {index}'), \\\n",
657-
" name = 'predictions', context={'type':'labels'}) \n",
653+
"\n",
654+
" aim_run.track(\n",
655+
" aim.Image(val_inputs[0, 0, :, :, slice_to_track], caption=f'Input Image: {index}'),\n",
656+
" name='validation',\n",
657+
" context={'type': 'input'},\n",
658+
" )\n",
659+
" aim_run.track(\n",
660+
" aim.Image(val_labels[0, 0, :, :, slice_to_track], caption=f'Label Image: {index}'),\n",
661+
" name='validation',\n",
662+
" context={'type': 'label'},\n",
663+
" )\n",
664+
" aim_run.track(\n",
665+
" aim.Image(output, caption=f'Predicted Label: {index}'),\n",
666+
" name='predictions',\n",
667+
" context={'type': 'labels'},\n",
668+
" )\n",
658669
"\n",
659670
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n",
660671
" val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n",
661672
" # compute metric for current iteration\n",
662673
" dice_metric(y_pred=val_outputs, y=val_labels)\n",
663-
" \n",
674+
"\n",
664675
" # aggregate the final mean dice result\n",
665676
" metric = dice_metric.aggregate().item()\n",
666677
" # track val metric\n",
667-
" aim_run.track(metric, name=\"val_metric\", context={'type':loss_type})\n",
678+
" aim_run.track(metric, name=\"val_metric\", context={'type': loss_type})\n",
668679
"\n",
669680
" # reset the status for next validation round\n",
670681
" dice_metric.reset()\n",
@@ -675,16 +686,16 @@
675686
" best_metric_epoch = epoch + 1\n",
676687
" torch.save(model.state_dict(), os.path.join(\n",
677688
" root_dir, \"best_metric_model.pth\"))\n",
678-
" \n",
689+
"\n",
679690
" best_model_log_message = f\"saved new best metric model at the {epoch+1}th epoch\"\n",
680691
" aim_run.track(aim.Text(best_model_log_message), name='best_model_log_message', epoch=epoch+1)\n",
681692
" print(best_model_log_message)\n",
682-
" \n",
693+
"\n",
683694
" message1 = f\"current epoch: {epoch + 1} current mean dice: {metric:.4f}\"\n",
684695
" message2 = f\"\\nbest mean dice: {best_metric:.4f} \"\n",
685696
" message3 = f\"at epoch: {best_metric_epoch}\"\n",
686-
" \n",
687-
" aim_run.track(aim.Text(message1 +\"\\n\" + message2 + message3), name='epoch_summary', epoch=epoch+1)\n",
697+
"\n",
698+
" aim_run.track(aim.Text(message1 + \"\\n\" + message2 + message3), name='epoch_summary', epoch=epoch+1)\n",
688699
" print(message1, message2, message3)\n"
689700
]
690701
},
@@ -903,7 +914,7 @@
903914
"provenance": []
904915
},
905916
"kernelspec": {
906-
"display_name": "Python 3 (ipykernel)",
917+
"display_name": "base",
907918
"language": "python",
908919
"name": "python3"
909920
},
@@ -917,7 +928,12 @@
917928
"name": "python",
918929
"nbconvert_exporter": "python",
919930
"pygments_lexer": "ipython3",
920-
"version": "3.8.13"
931+
"version": "3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10) \n[GCC 10.3.0]"
932+
},
933+
"vscode": {
934+
"interpreter": {
935+
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
936+
}
921937
}
922938
},
923939
"nbformat": 4,

3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@
251251
" json_data = json_data[key]\n",
252252
"\n",
253253
" for d in json_data:\n",
254-
" for k, v in d.items():\n",
254+
" for k in d:\n",
255255
" if isinstance(d[k], list):\n",
256256
" d[k] = [os.path.join(basedir, iv) for iv in d[k]]\n",
257257
" elif isinstance(d[k], str):\n",
@@ -559,17 +559,17 @@
559559
" acc_func(y_pred=val_output_convert, y=val_labels_list)\n",
560560
" acc, not_nans = acc_func.aggregate()\n",
561561
" run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())\n",
562-
" Dice_TC = run_acc.avg[0]\n",
563-
" Dice_WT = run_acc.avg[1]\n",
564-
" Dice_ET = run_acc.avg[2]\n",
562+
" dice_tc = run_acc.avg[0]\n",
563+
" dice_wt = run_acc.avg[1]\n",
564+
" dice_et = run_acc.avg[2]\n",
565565
" print(\n",
566566
" \"Val {}/{} {}/{}\".format(epoch, max_epochs, idx, len(loader)),\n",
567-
" \", Dice_TC:\",\n",
568-
" Dice_TC,\n",
569-
" \", Dice_WT:\",\n",
570-
" Dice_WT,\n",
571-
" \", Dice_ET:\",\n",
572-
" Dice_ET,\n",
567+
" \", dice_tc:\",\n",
568+
" dice_tc,\n",
569+
" \", dice_wt:\",\n",
570+
" dice_wt,\n",
571+
" \", dice_et:\",\n",
572+
" dice_et,\n",
573573
" \", time {:.2f}s\".format(time.time() - start_time),\n",
574574
" )\n",
575575
" start_time = time.time()\n",
@@ -605,10 +605,10 @@
605605
"):\n",
606606
"\n",
607607
" val_acc_max = 0.0\n",
608-
" Dices_TC = []\n",
609-
" Dices_WT = []\n",
610-
" Dices_ET = []\n",
611-
" Dices_avg = []\n",
608+
" dices_tc = []\n",
609+
" dices_wt = []\n",
610+
" dices_et = []\n",
611+
" dices_avg = []\n",
612612
" loss_epochs = []\n",
613613
" trains_epoch = []\n",
614614
" for epoch in range(start_epoch, max_epochs):\n",
@@ -640,26 +640,26 @@
640640
" post_sigmoid=post_sigmoid,\n",
641641
" post_pred=post_pred,\n",
642642
" )\n",
643-
" Dice_TC = val_acc[0]\n",
644-
" Dice_WT = val_acc[1]\n",
645-
" Dice_ET = val_acc[2]\n",
643+
" dice_tc = val_acc[0]\n",
644+
" dice_wt = val_acc[1]\n",
645+
" dice_et = val_acc[2]\n",
646646
" val_avg_acc = np.mean(val_acc)\n",
647647
" print(\n",
648648
" \"Final validation stats {}/{}\".format(epoch, max_epochs - 1),\n",
649-
" \", Dice_TC:\",\n",
650-
" Dice_TC,\n",
651-
" \", Dice_WT:\",\n",
652-
" Dice_WT,\n",
653-
" \", Dice_ET:\",\n",
654-
" Dice_ET,\n",
649+
" \", dice_tc:\",\n",
650+
" dice_tc,\n",
651+
" \", dice_wt:\",\n",
652+
" dice_wt,\n",
653+
" \", dice_et:\",\n",
654+
" dice_et,\n",
655655
" \", Dice_Avg:\",\n",
656656
" val_avg_acc,\n",
657657
" \", time {:.2f}s\".format(time.time() - epoch_time),\n",
658658
" )\n",
659-
" Dices_TC.append(Dice_TC)\n",
660-
" Dices_WT.append(Dice_WT)\n",
661-
" Dices_ET.append(Dice_ET)\n",
662-
" Dices_avg.append(val_avg_acc)\n",
659+
" dices_tc.append(dice_tc)\n",
660+
" dices_wt.append(dice_wt)\n",
661+
" dices_et.append(dice_et)\n",
662+
" dices_avg.append(val_avg_acc)\n",
663663
" if val_avg_acc > val_acc_max:\n",
664664
" print(\"new best ({:.6f} --> {:.6f}). \".format(val_acc_max, val_avg_acc))\n",
665665
" val_acc_max = val_avg_acc\n",
@@ -672,10 +672,10 @@
672672
" print(\"Training Finished !, Best Accuracy: \", val_acc_max)\n",
673673
" return (\n",
674674
" val_acc_max,\n",
675-
" Dices_TC,\n",
676-
" Dices_WT,\n",
677-
" Dices_ET,\n",
678-
" Dices_avg,\n",
675+
" dices_tc,\n",
676+
" dices_wt,\n",
677+
" dices_et,\n",
678+
" dices_avg,\n",
679679
" loss_epochs,\n",
680680
" trains_epoch,\n",
681681
" )"
@@ -700,10 +700,10 @@
700700
"\n",
701701
"(\n",
702702
" val_acc_max,\n",
703-
" Dices_TC,\n",
704-
" Dices_WT,\n",
705-
" Dices_ET,\n",
706-
" Dices_avg,\n",
703+
" dices_tc,\n",
704+
" dices_wt,\n",
705+
" dices_et,\n",
706+
" dices_avg,\n",
707707
" loss_epochs,\n",
708708
" trains_epoch,\n",
709709
") = trainer(\n",
@@ -784,21 +784,21 @@
784784
"plt.subplot(1, 2, 2)\n",
785785
"plt.title(\"Val Mean Dice\")\n",
786786
"plt.xlabel(\"epoch\")\n",
787-
"plt.plot(trains_epoch, Dices_avg, color=\"green\")\n",
787+
"plt.plot(trains_epoch, dices_avg, color=\"green\")\n",
788788
"plt.show()\n",
789789
"plt.figure(\"train\", (18, 6))\n",
790790
"plt.subplot(1, 3, 1)\n",
791791
"plt.title(\"Val Mean Dice TC\")\n",
792792
"plt.xlabel(\"epoch\")\n",
793-
"plt.plot(trains_epoch, Dices_TC, color=\"blue\")\n",
793+
"plt.plot(trains_epoch, dices_tc, color=\"blue\")\n",
794794
"plt.subplot(1, 3, 2)\n",
795795
"plt.title(\"Val Mean Dice WT\")\n",
796796
"plt.xlabel(\"epoch\")\n",
797-
"plt.plot(trains_epoch, Dices_WT, color=\"brown\")\n",
797+
"plt.plot(trains_epoch, dices_wt, color=\"brown\")\n",
798798
"plt.subplot(1, 3, 3)\n",
799799
"plt.title(\"Val Mean Dice ET\")\n",
800800
"plt.xlabel(\"epoch\")\n",
801-
"plt.plot(trains_epoch, Dices_ET, color=\"purple\")\n",
801+
"plt.plot(trains_epoch, dices_et, color=\"purple\")\n",
802802
"plt.show()"
803803
]
804804
},
@@ -912,7 +912,7 @@
912912
"\n",
913913
"\n",
914914
"with torch.no_grad():\n",
915-
" for idx, batch_data in enumerate(test_loader):\n",
915+
" for batch_data in test_loader:\n",
916916
" image = batch_data[\"image\"].cuda()\n",
917917
" prob = torch.sigmoid(model_inferer_test(image))\n",
918918
" seg = prob[0].detach().cpu().numpy()\n",

3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,9 @@
319319
"outputs": [],
320320
"source": [
321321
"data_dir = \"data/\"\n",
322-
"split_JSON = \"dataset_0.json\"\n",
322+
"split_json = \"dataset_0.json\"\n",
323323
"\n",
324-
"datasets = data_dir + split_JSON\n",
324+
"datasets = data_dir + split_json\n",
325325
"datalist = load_decathlon_datalist(datasets, True, \"training\")\n",
326326
"val_files = load_decathlon_datalist(datasets, True, \"validation\")\n",
327327
"train_ds = CacheDataset(\n",
@@ -496,7 +496,7 @@
496496
"def validation(epoch_iterator_val):\n",
497497
" model.eval()\n",
498498
" with torch.no_grad():\n",
499-
" for step, batch in enumerate(epoch_iterator_val):\n",
499+
" for batch in epoch_iterator_val:\n",
500500
" val_inputs, val_labels = (batch[\"image\"].cuda(), batch[\"label\"].cuda())\n",
501501
" with torch.cuda.amp.autocast():\n",
502502
" val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)\n",
@@ -536,10 +536,7 @@
536536
" scaler.step(optimizer)\n",
537537
" scaler.update()\n",
538538
" optimizer.zero_grad()\n",
539-
" epoch_iterator.set_description(\n",
540-
" \"Training (%d / %d Steps) (loss=%2.5f)\"\n",
541-
" % (global_step, max_iterations, loss)\n",
542-
" )\n",
539+
" epoch_iterator.set_description(f\"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})\")\n",
543540
" if (\n",
544541
" global_step % eval_num == 0 and global_step != 0\n",
545542
" ) or global_step == max_iterations:\n",
@@ -731,7 +728,7 @@
731728
],
732729
"metadata": {
733730
"kernelspec": {
734-
"display_name": "Python 3 (ipykernel)",
731+
"display_name": "base",
735732
"language": "python",
736733
"name": "python3"
737734
},
@@ -745,7 +742,12 @@
745742
"name": "python",
746743
"nbconvert_exporter": "python",
747744
"pygments_lexer": "ipython3",
748-
"version": "3.8.13"
745+
"version": "3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10) \n[GCC 10.3.0]"
746+
},
747+
"vscode": {
748+
"interpreter": {
749+
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
750+
}
749751
}
750752
},
751753
"nbformat": 4,

0 commit comments

Comments
 (0)