Skip to content

Commit e721728

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 172c64e commit e721728

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

modules/network_contraints/unet_plusplus.ipynb

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
"print(monai.networks.layers.factories.Norm.factories.keys())\n",
118118
"for norm_layer in monai.networks.layers.factories.Norm.factories.keys():\n",
119119
" try:\n",
120-
" model = BasicUnetPlusPlus( \n",
120+
" model = BasicUnetPlusPlus(\n",
121121
" spatial_dims=3,\n",
122122
" in_channels=3,\n",
123123
" out_channels=3,\n",
@@ -126,8 +126,7 @@
126126
" norm=norm_layer,\n",
127127
" )\n",
128128
" except Exception as msg:\n",
129-
" print(f\"Exception layer: {msg}\")\n",
130-
" "
129+
" print(f\"Exception layer: {msg}\")"
131130
]
132131
},
133132
{
@@ -192,27 +191,24 @@
192191
"\n",
193192
"def make_model_with_layer(layer_norm):\n",
194193
" return BasicUnetPlusPlus(\n",
195-
" spatial_dims=2,\n",
196-
" in_channels=3,\n",
197-
" out_channels=1,\n",
198-
" features=(32, 32, 64, 128, 256, 32),\n",
199-
" norm=layer_norm\n",
194+
" spatial_dims=2, in_channels=3, out_channels=1, features=(32, 32, 64, 128, 256, 32), norm=layer_norm\n",
200195
" )\n",
201196
"\n",
197+
"\n",
202198
"def test_min_dim():\n",
203199
" MIN_EDGE = 16\n",
204200
" batch_size, spatial_dim, H, W = 1, 3, MIN_EDGE, MIN_EDGE\n",
205201
" MODEL_BY_NORM_LAYER: Dict[str, BasicUnetPlusPlus] = {}\n",
206202
" print(\"Prepare model\")\n",
207-
" for norm_layer in ['instance', 'batch']:\n",
203+
" for norm_layer in [\"instance\", \"batch\"]:\n",
208204
" MODEL_BY_NORM_LAYER[norm_layer] = make_model_with_layer(norm_layer)\n",
209-
" \n",
205+
"\n",
210206
" # print(f\"Input dimension {(batch_size, spatial_dim, H, W)} that will cause error\")\n",
211-
" for norm_layer in ['instance', 'batch']:\n",
212-
" print(\"=\"*10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\"*10)\n",
207+
" for norm_layer in [\"instance\", \"batch\"]:\n",
208+
" print(\"=\" * 10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\" * 10)\n",
213209
" model = MODEL_BY_NORM_LAYER[norm_layer]\n",
214210
" print(\"_\" * 10 + \" Changing the H dimension of 2D input \" + \"_\" * 10)\n",
215-
" for _H_temp in [H, H*2]:\n",
211+
" for _H_temp in [H, H * 2]:\n",
216212
" try:\n",
217213
" x = torch.ones(batch_size, spatial_dim, _H_temp, W)\n",
218214
" print(f\">> Using Input.shape={x.shape}\")\n",
@@ -231,6 +227,7 @@
231227
" print(f\">> Exception: {msg}\\n\")\n",
232228
" pass\n",
233229
"\n",
230+
"\n",
234231
"with torch.no_grad():\n",
235232
" test_min_dim()"
236233
]
@@ -296,13 +293,10 @@
296293
],
297294
"source": [
298295
"# Example about the pooling with odd shape\n",
299-
"un_pool = torch.Tensor(\n",
300-
" [[1, 2, 3, 4, 5],\n",
301-
" [1, 2, 3, 4, 5],\n",
302-
" [1, 2, 3, 4, 5],\n",
303-
" [1, 2, 3, 4, 5],\n",
304-
" [1, 2, 3, 4, 5]]\n",
305-
") * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n",
296+
"un_pool = (\n",
297+
" torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])\n",
298+
" * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n",
299+
")\n",
306300
"un_pool = un_pool[None, None, ...]\n",
307301
"\n",
308302
"pooled = nn.MaxPool2d(kernel_size=2)(un_pool)\n",
@@ -342,14 +336,14 @@
342336
"# [batch_size, spatial_dim, D, H, W]\n",
343337
"# )\n",
344338
"# # y = model.forward(x)\n",
345-
"# # print(f\"{x.shape=}\") \n",
339+
"# # print(f\"{x.shape=}\")\n",
346340
"# # print(f\"{[o.shape for o in y]=}\")\n",
347-
" \n",
341+
"\n",
348342
"# return model, x\n",
349343
"\n",
350344
"# # Loss edge info. if input dim is not divisible by 2 when pooling\n",
351345
"# # Ensure the lowest image dimension\n",
352-
"# # As the lowest dim before the norm is [1x1x1] \n",
346+
"# # As the lowest dim before the norm is [1x1x1]\n",
353347
"# # -> instance and batch norm don't allow that, so make it at least\n",
354348
"# # [2x1x1]\n",
355349
"# from monai.networks.blocks import ADN\n",
@@ -358,23 +352,23 @@
358352
"# print(f\"Input {x.shape=}\")\n",
359353
"# model.eval()\n",
360354
"# # 2 conv\n",
361-
"# x_0_0 = model.conv_0_0(x) \n",
355+
"# x_0_0 = model.conv_0_0(x)\n",
362356
"# # down conv\n",
363-
"# x_1_0 = model.conv_1_0(x_0_0) \n",
357+
"# x_1_0 = model.conv_1_0(x_0_0)\n",
364358
"# print(f\"{x_1_0.shape=}\")\n",
365359
"# x_0_1 = model.upcat_0_1(x_1_0, x_0_0)\n",
366360
"# print(f\"{x_0_1.shape=}\")\n",
367-
"# # x_2_0 = model.conv_2_0(x_1_0) \n",
361+
"# # x_2_0 = model.conv_2_0(x_1_0)\n",
368362
"# # print(f\"{x_2_0.shape=}\")\n",
369-
"# # x_3_0 = model.conv_3_0(x_2_0) \n",
363+
"# # x_3_0 = model.conv_3_0(x_2_0)\n",
370364
"# # print(f\"{x_3_0.shape=}\") # 2 x 2 -> 1 x 1\n",
371365
"# # pooled = nn.MaxPool3d(kernel_size=2)(x_3_0)\n",
372366
"# # print(f\"{pooled.shape=}\") # 2 x 2 -> 1 x 1\n",
373367
"# # conved1 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)(pooled)\n",
374368
"# # conved2 = nn.Conv3d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)(conved1)\n",
375369
"# # print(f\"{conved1.shape=}\") # 2 x 2 -> 1 x 1\n",
376370
"# # print(f\"{conved2.shape=}\") # 2 x 2 -> 1 x 1\n",
377-
" \n",
371+
"\n",
378372
"# # normed = ADN(\n",
379373
"# # ordering='NDA',\n",
380374
"# # in_channels=256,\n",
@@ -385,7 +379,7 @@
385379
"# # )(conved2)\n",
386380
"# # print(f\"{normed.shape=}\") # 2 x 2 -> 1 x 1\n",
387381
"\n",
388-
"# # x_4_0 = model.conv_4_0(x_3_0) \n",
382+
"# # x_4_0 = model.conv_4_0(x_3_0)\n",
389383
"# # print(f\"{x_4_0.shape=}\")\n",
390384
"\n",
391385
"# # Up path\n",

0 commit comments

Comments
 (0)