|
117 | 117 | "print(monai.networks.layers.factories.Norm.factories.keys())\n",
|
118 | 118 | "for norm_layer in monai.networks.layers.factories.Norm.factories.keys():\n",
|
119 | 119 | " try:\n",
|
120 |
| - " model = BasicUnetPlusPlus( \n", |
| 120 | + " model = BasicUnetPlusPlus(\n", |
121 | 121 | " spatial_dims=3,\n",
|
122 | 122 | " in_channels=3,\n",
|
123 | 123 | " out_channels=3,\n",
|
|
126 | 126 | " norm=norm_layer,\n",
|
127 | 127 | " )\n",
|
128 | 128 | " except Exception as msg:\n",
|
129 |
| - " print(f\"Exception layer: {msg}\")\n", |
130 |
| - " " |
| 129 | + " print(f\"Exception layer: {msg}\")" |
131 | 130 | ]
|
132 | 131 | },
|
133 | 132 | {
|
|
192 | 191 | "\n",
|
193 | 192 | "def make_model_with_layer(layer_norm):\n",
|
194 | 193 | " return BasicUnetPlusPlus(\n",
|
195 |
| - " spatial_dims=2,\n", |
196 |
| - " in_channels=3,\n", |
197 |
| - " out_channels=1,\n", |
198 |
| - " features=(32, 32, 64, 128, 256, 32),\n", |
199 |
| - " norm=layer_norm\n", |
| 194 | + " spatial_dims=2, in_channels=3, out_channels=1, features=(32, 32, 64, 128, 256, 32), norm=layer_norm\n", |
200 | 195 | " )\n",
|
201 | 196 | "\n",
|
| 197 | + "\n", |
202 | 198 | "def test_min_dim():\n",
|
203 | 199 | " MIN_EDGE = 16\n",
|
204 | 200 | " batch_size, spatial_dim, H, W = 1, 3, MIN_EDGE, MIN_EDGE\n",
|
205 | 201 | " MODEL_BY_NORM_LAYER: Dict[str, BasicUnetPlusPlus] = {}\n",
|
206 | 202 | " print(\"Prepare model\")\n",
|
207 |
| - " for norm_layer in ['instance', 'batch']:\n", |
| 203 | + " for norm_layer in [\"instance\", \"batch\"]:\n", |
208 | 204 | " MODEL_BY_NORM_LAYER[norm_layer] = make_model_with_layer(norm_layer)\n",
|
209 |
| - " \n", |
| 205 | + "\n", |
210 | 206 | " # print(f\"Input dimension {(batch_size, spatial_dim, H, W)} that will cause error\")\n",
|
211 |
| - " for norm_layer in ['instance', 'batch']:\n", |
212 |
| - " print(\"=\"*10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\"*10)\n", |
| 207 | + " for norm_layer in [\"instance\", \"batch\"]:\n", |
| 208 | + " print(\"=\" * 10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\" * 10)\n", |
213 | 209 | " model = MODEL_BY_NORM_LAYER[norm_layer]\n",
|
214 | 210 | " print(\"_\" * 10 + \" Changing the H dimension of 2D input \" + \"_\" * 10)\n",
|
215 |
| - " for _H_temp in [H, H*2]:\n", |
| 211 | + " for _H_temp in [H, H * 2]:\n", |
216 | 212 | " try:\n",
|
217 | 213 | " x = torch.ones(batch_size, spatial_dim, _H_temp, W)\n",
|
218 | 214 | " print(f\">> Using Input.shape={x.shape}\")\n",
|
|
231 | 227 | " print(f\">> Exception: {msg}\\n\")\n",
|
232 | 228 | " pass\n",
|
233 | 229 | "\n",
|
| 230 | + "\n", |
234 | 231 | "with torch.no_grad():\n",
|
235 | 232 | " test_min_dim()"
|
236 | 233 | ]
|
|
296 | 293 | ],
|
297 | 294 | "source": [
|
298 | 295 | "# Example about the pooling with odd shape\n",
|
299 |
| - "un_pool = torch.Tensor(\n", |
300 |
| - " [[1, 2, 3, 4, 5],\n", |
301 |
| - " [1, 2, 3, 4, 5],\n", |
302 |
| - " [1, 2, 3, 4, 5],\n", |
303 |
| - " [1, 2, 3, 4, 5],\n", |
304 |
| - " [1, 2, 3, 4, 5]]\n", |
305 |
| - ") * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n", |
| 296 | + "un_pool = (\n", |
| 297 | + " torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])\n", |
| 298 | + " * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n", |
| 299 | + ")\n", |
306 | 300 | "un_pool = un_pool[None, None, ...]\n",
|
307 | 301 | "\n",
|
308 | 302 | "pooled = nn.MaxPool2d(kernel_size=2)(un_pool)\n",
|
|
342 | 336 | "# [batch_size, spatial_dim, D, H, W]\n",
|
343 | 337 | "# )\n",
|
344 | 338 | "# # y = model.forward(x)\n",
|
345 |
| - "# # print(f\"{x.shape=}\") \n", |
| 339 | + "# # print(f\"{x.shape=}\")\n", |
346 | 340 | "# # print(f\"{[o.shape for o in y]=}\")\n",
|
347 |
| - " \n", |
| 341 | + "\n", |
348 | 342 | "# return model, x\n",
|
349 | 343 | "\n",
|
350 | 344 | "# # Loss edge info. if input dim is not divisible by 2 when pooling\n",
|
351 | 345 | "# # Ensure the lowest image dimension\n",
|
352 |
| - "# # As the lowest dim before the norm is [1x1x1] \n", |
| 346 | + "# # As the lowest dim before the norm is [1x1x1]\n", |
353 | 347 | "# # -> instance and batch norm don't allow that, so make it at least\n",
|
354 | 348 | "# # [2x1x1]\n",
|
355 | 349 | "# from monai.networks.blocks import ADN\n",
|
|
358 | 352 | "# print(f\"Input {x.shape=}\")\n",
|
359 | 353 | "# model.eval()\n",
|
360 | 354 | "# # 2 conv\n",
|
361 |
| - "# x_0_0 = model.conv_0_0(x) \n", |
| 355 | + "# x_0_0 = model.conv_0_0(x)\n", |
362 | 356 | "# # down conv\n",
|
363 |
| - "# x_1_0 = model.conv_1_0(x_0_0) \n", |
| 357 | + "# x_1_0 = model.conv_1_0(x_0_0)\n", |
364 | 358 | "# print(f\"{x_1_0.shape=}\")\n",
|
365 | 359 | "# x_0_1 = model.upcat_0_1(x_1_0, x_0_0)\n",
|
366 | 360 | "# print(f\"{x_0_1.shape=}\")\n",
|
367 |
| - "# # x_2_0 = model.conv_2_0(x_1_0) \n", |
| 361 | + "# # x_2_0 = model.conv_2_0(x_1_0)\n", |
368 | 362 | "# # print(f\"{x_2_0.shape=}\")\n",
|
369 |
| - "# # x_3_0 = model.conv_3_0(x_2_0) \n", |
| 363 | + "# # x_3_0 = model.conv_3_0(x_2_0)\n", |
370 | 364 | "# # print(f\"{x_3_0.shape=}\") # 2 x 2 -> 1 x 1\n",
|
371 | 365 | "# # pooled = nn.MaxPool3d(kernel_size=2)(x_3_0)\n",
|
372 | 366 | "# # print(f\"{pooled.shape=}\") # 2 x 2 -> 1 x 1\n",
|
373 | 367 | "# # conved1 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)(pooled)\n",
|
374 | 368 | "# # conved2 = nn.Conv3d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)(conved1)\n",
|
375 | 369 | "# # print(f\"{conved1.shape=}\") # 2 x 2 -> 1 x 1\n",
|
376 | 370 | "# # print(f\"{conved2.shape=}\") # 2 x 2 -> 1 x 1\n",
|
377 |
| - " \n", |
| 371 | + "\n", |
378 | 372 | "# # normed = ADN(\n",
|
379 | 373 | "# # ordering='NDA',\n",
|
380 | 374 | "# # in_channels=256,\n",
|
|
385 | 379 | "# # )(conved2)\n",
|
386 | 380 | "# # print(f\"{normed.shape=}\") # 2 x 2 -> 1 x 1\n",
|
387 | 381 | "\n",
|
388 |
| - "# # x_4_0 = model.conv_4_0(x_3_0) \n", |
| 382 | + "# # x_4_0 = model.conv_4_0(x_3_0)\n", |
389 | 383 | "# # print(f\"{x_4_0.shape=}\")\n",
|
390 | 384 | "\n",
|
391 | 385 | "# # Up path\n",
|
|
0 commit comments