230 | 230 |   "\n",
231 | 231 |   "\n",
232 | 232 |   "def test_min_dim():\n",
233 |     | - "    MIN_EDGE = 16\n",
234 |     | - "    batch_size, spatial_dim, H, W = 1, 3, MIN_EDGE, MIN_EDGE\n",
235 |     | - "    MODEL_BY_NORM_LAYER: Dict[str, BasicUnetPlusPlus] = {}\n",
    | 233 | + "    min_edge = 16\n",
    | 234 | + "    batch_size, spatial_dim, height, width = 1, 3, min_edge, min_edge\n",
    | 235 | + "    model_dict: Dict[str, BasicUnetPlusPlus] = {}\n",
236 | 236 |   "    print(\"Prepare model\")\n",
237 | 237 |   "    for norm_layer in [\"instance\", \"batch\"]:\n",
238 |     | - "        MODEL_BY_NORM_LAYER[norm_layer] = make_model_with_layer(norm_layer)\n",
    | 238 | + "        model_dict[norm_layer] = make_model_with_layer(norm_layer)\n",
239 | 239 |   "\n",
240 | 240 |   "    # print(f\"Input dimension {(batch_size, spatial_dim, H, W)} that will cause error\")\n",
241 | 241 |   "    for norm_layer in [\"instance\", \"batch\"]:\n",
242 | 242 |   "        print(\"=\" * 10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\" * 10)\n",
243 |     | - "        model = MODEL_BY_NORM_LAYER[norm_layer]\n",
    | 243 | + "        model = model_dict[norm_layer]\n",
244 | 244 |   "        print(\"_\" * 10 + \" Changing the H dimension of 2D input \" + \"_\" * 10)\n",
245 |     | - "        for _H_temp in [H, H * 2]:\n",
    | 245 | + "        for temp_height in [height, height * 2]:\n",
246 | 246 |   "            try:\n",
247 |     | - "                x = torch.ones(batch_size, spatial_dim, _H_temp, W)\n",
    | 247 | + "                x = torch.ones(batch_size, spatial_dim, temp_height, width)\n",
248 | 248 |   "                print(f\">> Using Input.shape={x.shape}\")\n",
249 | 249 |   "                model(x)\n",
250 | 250 |   "            except Exception as msg:\n",

254 | 254 |   "        print(\"_\" * 10 + \" Changing the batch size \" + \"_\" * 10)\n",
255 | 255 |   "        for batch_size_tmp in [1, 2]:\n",
256 | 256 |   "            try:\n",
257 |     | - "                x = torch.ones(batch_size_tmp, spatial_dim, H, W)\n",
    | 257 | + "                x = torch.ones(batch_size_tmp, spatial_dim, height, width)\n",
258 | 258 |   "                print(f\">> Input.shape={x.shape}\")\n",
259 | 259 |   "                model(x)\n",
260 | 260 |   "            except Exception as msg:\n",
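
Note: make_model_with_layer is defined elsewhere in the notebook and is not part of this hunk. As a minimal sketch only (assuming the helper builds a 2D BasicUnetPlusPlus from monai.networks.nets with the requested normalization layer; the spatial_dims/in_channels/out_channels values below are illustrative assumptions, not values taken from the notebook), it could look like:

from monai.networks.nets import BasicUnetPlusPlus

def make_model_with_layer(norm_layer: str) -> BasicUnetPlusPlus:
    # Hypothetical helper: build a 2D BasicUNet++ whose blocks use the given
    # norm layer ("instance" or "batch"), matching how the test above calls it.
    return BasicUnetPlusPlus(spatial_dims=2, in_channels=3, out_channels=2, norm=norm_layer)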