|
1 | 1 | {
|
2 | 2 | "cells": [
|
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "Copyright (c) MONAI Consortium \n", |
| 8 | + "Licensed under the Apache License, Version 2.0 (the \"License\"); \n", |
| 9 | + "you may not use this file except in compliance with the License. \n", |
| 10 | + "You may obtain a copy of the License at \n", |
| 11 | + " http://www.apache.org/licenses/LICENSE-2.0 \n", |
| 12 | + "Unless required by applicable law or agreed to in writing, software \n", |
| 13 | + "distributed under the License is distributed on an \"AS IS\" BASIS, \n", |
| 14 | + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n", |
| 15 | + "See the License for the specific language governing permissions and \n", |
| 16 | + "limitations under the License.\n", |
| 17 | + "\n", |
| 18 | + "# UNet++ input size constrains\n", |
| 19 | + "\n", |
| 20 | + "MONAI provides an enhanced version of UNet (``monai.networks.nets.UNet ``), which not only supports residual units, but also can use more hyperparameters (like ``strides``, ``kernel_size`` and ``up_kernel_size``) than ``monai.networks.nets.BasicUNet``. However, ``UNet`` has some constrains for both network hyperparameters and sizes of input.\n", |
| 21 | + "\n", |
| 22 | + "MONAI provides a version of UNET++ (`` monai.networks.nets.BasicUnetPlusPlus ``), with fixed num. of down-scale layer, strides of 2. The configurations you can change are: the number input and output channels, number of hidden channels (6 different layers), norm and activation, bias of convolution, dropout rate, and up-sampling model. As `UNET`, different model configurations can affect the input shape.\n", |
| 23 | + "\n", |
| 24 | + "The constrains of hyperparameters can be found in the docstring of the network, and this tutorial is focused on how to determine a reasonable input size." |
| 25 | + ] |
| 26 | + }, |
| 27 | + { |
| 28 | + "cell_type": "markdown", |
| 29 | + "metadata": {}, |
| 30 | + "source": [ |
| 31 | + "## Setup enviroments" |
| 32 | + ] |
| 33 | + }, |
3 | 34 | {
|
4 | 35 | "cell_type": "code",
|
5 |
| - "execution_count": 1, |
| 36 | + "execution_count": null, |
6 | 37 | "metadata": {},
|
7 | 38 | "outputs": [],
|
8 | 39 | "source": [
|
9 |
| - "# Copyright (c) MONAI Consortium\n", |
10 |
| - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", |
11 |
| - "# you may not use this file except in compliance with the License.\n", |
12 |
| - "# You may obtain a copy of the License at\n", |
13 |
| - "# http://www.apache.org/licenses/LICENSE-2.0\n", |
14 |
| - "# Unless required by applicable law or agreed to in writing, software\n", |
15 |
| - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", |
16 |
| - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", |
17 |
| - "# See the License for the specific language governing permissions and\n", |
18 |
| - "# limitations under the License." |
| 40 | + "!python -c \"import monai\" || pip install -q monai-weekly" |
| 41 | + ] |
| 42 | + }, |
| 43 | + { |
| 44 | + "cell_type": "markdown", |
| 45 | + "metadata": {}, |
| 46 | + "source": [ |
| 47 | + "## Setup imports" |
19 | 48 | ]
|
20 | 49 | },
|
21 | 50 | {
|
|
41 | 70 | "scikit-image version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
42 | 71 | "scipy version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
43 | 72 | "Pillow version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
44 |
| - "Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.\n", |
| 73 | + "Tensorboard version: 2.13.0\n", |
45 | 74 | "gdown version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
46 | 75 | "TorchVision version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
47 |
| - "tqdm version: NOT INSTALLED or UNKNOWN VERSION.\n", |
| 76 | + "tqdm version: 4.65.0\n", |
48 | 77 | "lmdb version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
49 | 78 | "psutil version: 5.9.5\n",
|
50 | 79 | "pandas version: NOT INSTALLED or UNKNOWN VERSION.\n",
|
|
62 | 91 | "source": [
|
63 | 92 | "from monai.networks.nets import BasicUnetPlusPlus\n",
|
64 | 93 | "import monai\n",
|
65 |
| - "import math\n", |
66 | 94 | "import torch\n",
|
67 | 95 | "import torch.nn as nn\n",
|
| 96 | + "from typing import Dict\n", |
68 | 97 | "\n",
|
69 |
| - "monai.config.print_config()" |
| 98 | + "monai.config.print_config()\n" |
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "markdown", |
| 103 | + "metadata": {}, |
| 104 | + "source": [ |
| 105 | + "## Check UNet++ structure" |
70 | 106 | ]
|
71 | 107 | },
|
72 | 108 | {
|
|
117 | 153 | "print(monai.networks.layers.factories.Norm.factories.keys())\n",
|
118 | 154 | "for norm_layer in monai.networks.layers.factories.Norm.factories.keys():\n",
|
119 | 155 | " try:\n",
|
120 |
| - " model = BasicUnetPlusPlus( \n", |
| 156 | + " model = BasicUnetPlusPlus(\n", |
121 | 157 | " spatial_dims=3,\n",
|
122 | 158 | " in_channels=3,\n",
|
123 | 159 | " out_channels=3,\n",
|
|
126 | 162 | " norm=norm_layer,\n",
|
127 | 163 | " )\n",
|
128 | 164 | " except Exception as msg:\n",
|
129 |
| - " print(f\"Exception layer: {msg}\")\n", |
130 |
| - " " |
| 165 | + " print(f\"Exception layer: {msg}\")\n" |
131 | 166 | ]
|
132 | 167 | },
|
133 | 168 | {
|
134 | 169 | "cell_type": "markdown",
|
135 | 170 | "metadata": {},
|
136 | 171 | "source": [
|
137 | 172 | "# Normalization\n",
|
| 173 | + "\n", |
138 | 174 | "UNET++ use the same `TwoConv`, `Down`, and `UpCat` as UNet. Therefore, you can referred to the `modules/UNet_input_size_constrains.ipynb` for break down analysis. For summary, the constraints for these types of normalization are:\n",
|
139 | 175 | "\n",
|
140 | 176 | "- Instance Norm: the product of spatial dimension must > 1 (not include channel and batch)\n",
|
|
144 | 180 | "\n",
|
145 | 181 | "As for UNET++ have 4 down-sampling blocks with 2x kernel size, with no argument to change this behavior, the smallest edge we can have is `2**4 = 16`, and after the last down-sampling block, the `vector.shape = [..., ..., 1, 1]` or (`[..., ..., 1, 1, 1]` for 3D), which will cause error for the Normalization layer.\n",
|
146 | 182 | "\n",
|
147 |
| - "See the test code below for examples of batch norm and instance norm" |
| 183 | + "See the test code below for examples of batch norm and instance norm\n" |
148 | 184 | ]
|
149 | 185 | },
|
150 | 186 | {
|
|
187 | 223 | }
|
188 | 224 | ],
|
189 | 225 | "source": [
|
190 |
| - "from typing import Dict\n", |
191 |
| - "\n", |
192 |
| - "\n", |
193 | 226 | "def make_model_with_layer(layer_norm):\n",
|
194 | 227 | " return BasicUnetPlusPlus(\n",
|
195 |
| - " spatial_dims=2,\n", |
196 |
| - " in_channels=3,\n", |
197 |
| - " out_channels=1,\n", |
198 |
| - " features=(32, 32, 64, 128, 256, 32),\n", |
199 |
| - " norm=layer_norm\n", |
| 228 | + " spatial_dims=2, in_channels=3, out_channels=1, features=(32, 32, 64, 128, 256, 32), norm=layer_norm\n", |
200 | 229 | " )\n",
|
201 | 230 | "\n",
|
| 231 | + "\n", |
202 | 232 | "def test_min_dim():\n",
|
203 | 233 | " MIN_EDGE = 16\n",
|
204 | 234 | " batch_size, spatial_dim, H, W = 1, 3, MIN_EDGE, MIN_EDGE\n",
|
205 | 235 | " MODEL_BY_NORM_LAYER: Dict[str, BasicUnetPlusPlus] = {}\n",
|
206 | 236 | " print(\"Prepare model\")\n",
|
207 |
| - " for norm_layer in ['instance', 'batch']:\n", |
| 237 | + " for norm_layer in [\"instance\", \"batch\"]:\n", |
208 | 238 | " MODEL_BY_NORM_LAYER[norm_layer] = make_model_with_layer(norm_layer)\n",
|
209 |
| - " \n", |
| 239 | + "\n", |
210 | 240 | " # print(f\"Input dimension {(batch_size, spatial_dim, H, W)} that will cause error\")\n",
|
211 |
| - " for norm_layer in ['instance', 'batch']:\n", |
212 |
| - " print(\"=\"*10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\"*10)\n", |
| 241 | + " for norm_layer in [\"instance\", \"batch\"]:\n", |
| 242 | + " print(\"=\" * 10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\" * 10)\n", |
213 | 243 | " model = MODEL_BY_NORM_LAYER[norm_layer]\n",
|
214 | 244 | " print(\"_\" * 10 + \" Changing the H dimension of 2D input \" + \"_\" * 10)\n",
|
215 |
| - " for _H_temp in [H, H*2]:\n", |
| 245 | + " for _H_temp in [H, H * 2]:\n", |
216 | 246 | " try:\n",
|
217 | 247 | " x = torch.ones(batch_size, spatial_dim, _H_temp, W)\n",
|
218 | 248 | " print(f\">> Using Input.shape={x.shape}\")\n",
|
|
231 | 261 | " print(f\">> Exception: {msg}\\n\")\n",
|
232 | 262 | " pass\n",
|
233 | 263 | "\n",
|
| 264 | + "\n", |
234 | 265 | "with torch.no_grad():\n",
|
235 | 266 | " test_min_dim()"
|
236 | 267 | ]
|
|
271 | 302 | "\n",
|
272 | 303 | " return x\n",
|
273 | 304 | "\n",
|
274 |
| - "```" |
| 305 | + "```\n" |
275 | 306 | ]
|
276 | 307 | },
|
277 | 308 | {
|
278 | 309 | "cell_type": "code",
|
279 |
| - "execution_count": 86, |
| 310 | + "execution_count": 2, |
280 | 311 | "metadata": {},
|
281 | 312 | "outputs": [
|
282 | 313 | {
|
|
296 | 327 | ],
|
297 | 328 | "source": [
|
298 | 329 | "# Example about the pooling with odd shape\n",
|
299 |
| - "un_pool = torch.Tensor(\n", |
300 |
| - " [[1, 2, 3, 4, 5],\n", |
301 |
| - " [1, 2, 3, 4, 5],\n", |
302 |
| - " [1, 2, 3, 4, 5],\n", |
303 |
| - " [1, 2, 3, 4, 5],\n", |
304 |
| - " [1, 2, 3, 4, 5]]\n", |
305 |
| - ") * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n", |
| 330 | + "un_pool = (\n", |
| 331 | + " torch.Tensor(\n", |
| 332 | + " [\n", |
| 333 | + " [1, 2, 3, 4, 5],\n", |
| 334 | + " [1, 2, 3, 4, 5],\n", |
| 335 | + " [1, 2, 3, 4, 5],\n", |
| 336 | + " [1, 2, 3, 4, 5],\n", |
| 337 | + " [1, 2, 3, 4, 5],\n", |
| 338 | + " ]\n", |
| 339 | + " )\n", |
| 340 | + " * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n", |
| 341 | + ")\n", |
306 | 342 | "un_pool = un_pool[None, None, ...]\n",
|
307 | 343 | "\n",
|
308 | 344 | "pooled = nn.MaxPool2d(kernel_size=2)(un_pool)\n",
|
|
342 | 378 | "# [batch_size, spatial_dim, D, H, W]\n",
|
343 | 379 | "# )\n",
|
344 | 380 | "# # y = model.forward(x)\n",
|
345 |
| - "# # print(f\"{x.shape=}\") \n", |
| 381 | + "# # print(f\"{x.shape=}\")\n", |
346 | 382 | "# # print(f\"{[o.shape for o in y]=}\")\n",
|
347 |
| - " \n", |
| 383 | + "\n", |
348 | 384 | "# return model, x\n",
|
349 | 385 | "\n",
|
350 | 386 | "# # Loss edge info. if input dim is not divisible by 2 when pooling\n",
|
351 | 387 | "# # Ensure the lowest image dimension\n",
|
352 |
| - "# # As the lowest dim before the norm is [1x1x1] \n", |
| 388 | + "# # As the lowest dim before the norm is [1x1x1]\n", |
353 | 389 | "# # -> instance and batch norm don't allow that, so make it at least\n",
|
354 | 390 | "# # [2x1x1]\n",
|
355 | 391 | "# from monai.networks.blocks import ADN\n",
|
|
358 | 394 | "# print(f\"Input {x.shape=}\")\n",
|
359 | 395 | "# model.eval()\n",
|
360 | 396 | "# # 2 conv\n",
|
361 |
| - "# x_0_0 = model.conv_0_0(x) \n", |
| 397 | + "# x_0_0 = model.conv_0_0(x)\n", |
362 | 398 | "# # down conv\n",
|
363 |
| - "# x_1_0 = model.conv_1_0(x_0_0) \n", |
| 399 | + "# x_1_0 = model.conv_1_0(x_0_0)\n", |
364 | 400 | "# print(f\"{x_1_0.shape=}\")\n",
|
365 | 401 | "# x_0_1 = model.upcat_0_1(x_1_0, x_0_0)\n",
|
366 | 402 | "# print(f\"{x_0_1.shape=}\")\n",
|
367 |
| - "# # x_2_0 = model.conv_2_0(x_1_0) \n", |
| 403 | + "# # x_2_0 = model.conv_2_0(x_1_0)\n", |
368 | 404 | "# # print(f\"{x_2_0.shape=}\")\n",
|
369 |
| - "# # x_3_0 = model.conv_3_0(x_2_0) \n", |
| 405 | + "# # x_3_0 = model.conv_3_0(x_2_0)\n", |
370 | 406 | "# # print(f\"{x_3_0.shape=}\") # 2 x 2 -> 1 x 1\n",
|
371 | 407 | "# # pooled = nn.MaxPool3d(kernel_size=2)(x_3_0)\n",
|
372 | 408 | "# # print(f\"{pooled.shape=}\") # 2 x 2 -> 1 x 1\n",
|
373 | 409 | "# # conved1 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)(pooled)\n",
|
374 | 410 | "# # conved2 = nn.Conv3d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)(conved1)\n",
|
375 | 411 | "# # print(f\"{conved1.shape=}\") # 2 x 2 -> 1 x 1\n",
|
376 | 412 | "# # print(f\"{conved2.shape=}\") # 2 x 2 -> 1 x 1\n",
|
377 |
| - " \n", |
| 413 | + "\n", |
378 | 414 | "# # normed = ADN(\n",
|
379 | 415 | "# # ordering='NDA',\n",
|
380 | 416 | "# # in_channels=256,\n",
|
|
385 | 421 | "# # )(conved2)\n",
|
386 | 422 | "# # print(f\"{normed.shape=}\") # 2 x 2 -> 1 x 1\n",
|
387 | 423 | "\n",
|
388 |
| - "# # x_4_0 = model.conv_4_0(x_3_0) \n", |
| 424 | + "# # x_4_0 = model.conv_4_0(x_3_0)\n", |
389 | 425 | "# # print(f\"{x_4_0.shape=}\")\n",
|
390 | 426 | "\n",
|
391 | 427 | "# # Up path\n",
|
392 | 428 | "# # x_3_0 = model.conv_3_0(x_2_0)\n",
|
393 | 429 | "# # print(f\"{x_1_0.shape=}\")"
|
394 | 430 | ]
|
395 |
| - }, |
396 |
| - { |
397 |
| - "cell_type": "code", |
398 |
| - "execution_count": 37, |
399 |
| - "metadata": {}, |
400 |
| - "outputs": [ |
401 |
| - { |
402 |
| - "data": { |
403 |
| - "text/plain": [ |
404 |
| - "37" |
405 |
| - ] |
406 |
| - }, |
407 |
| - "execution_count": 37, |
408 |
| - "metadata": {}, |
409 |
| - "output_type": "execute_result" |
410 |
| - } |
411 |
| - ], |
412 |
| - "source": [] |
413 |
| - }, |
414 |
| - { |
415 |
| - "cell_type": "code", |
416 |
| - "execution_count": 29, |
417 |
| - "metadata": {}, |
418 |
| - "outputs": [ |
419 |
| - { |
420 |
| - "name": "stdout", |
421 |
| - "output_type": "stream", |
422 |
| - "text": [ |
423 |
| - "un_pool.shape=torch.Size([1, 1, 5, 5]), pooled.shape=torch.Size([1, 1, 2, 2])\n", |
424 |
| - "tensor([[[[ 1., 2., 3., 4., 5.],\n", |
425 |
| - " [ 2., 4., 6., 8., 10.],\n", |
426 |
| - " [ 3., 6., 9., 12., 15.],\n", |
427 |
| - " [ 4., 8., 12., 16., 20.],\n", |
428 |
| - " [ 5., 10., 15., 20., 25.]]]])\n", |
429 |
| - "tensor([[[[ 4., 8.],\n", |
430 |
| - " [ 8., 16.]]]])\n" |
431 |
| - ] |
432 |
| - } |
433 |
| - ], |
434 |
| - "source": [] |
435 | 431 | }
|
436 | 432 | ],
|
437 | 433 | "metadata": {
|
|
0 commit comments