Skip to content

Commit 53983a7

Browse files
652 Enhance get_kernels_strides of DynUNet pipeline (Project-MONAI#653)
* enhance get-kernels-strides Signed-off-by: Yiheng Wang <[email protected]> * modify valueerror message Signed-off-by: Yiheng Wang <[email protected]> * remove unused var Signed-off-by: Yiheng Wang <[email protected]> * change raise msg Signed-off-by: Yiheng Wang <[email protected]>
1 parent 98d4461 commit 53983a7

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

modules/dynunet_pipeline/create_network.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22

33
import torch
44
from monai.networks.nets import DynUNet
5-
65
from task_params import deep_supr_num, patch_size, spacing
76

87

98
def get_kernels_strides(task_id):
9+
"""
10+
This function is only used for decathlon datasets with the provided patch sizes.
11+
When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
12+
be divisible by the product of all strides in the corresponding dimension.
13+
In addition, the minimal spatial size should have at least one dimension that has twice the size of
14+
the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
15+
16+
"""
1017
sizes, spacings = patch_size[task_id], spacing[task_id]
18+
input_size = sizes
1119
strides, kernels = [], []
12-
1320
while True:
1421
spacing_ratio = [sp / min(spacings) for sp in spacings]
1522
stride = [
@@ -19,10 +26,16 @@ def get_kernels_strides(task_id):
1926
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
2027
if all(s == 1 for s in stride):
2128
break
29+
for idx, (i, j) in enumerate(zip(sizes, stride)):
30+
if i % j != 0:
31+
raise ValueError(
32+
f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
33+
)
2234
sizes = [i / j for i, j in zip(sizes, stride)]
2335
spacings = [i * j for i, j in zip(spacings, stride)]
2436
kernels.append(kernel)
2537
strides.append(stride)
38+
2639
strides.insert(0, len(spacings) * [1])
2740
kernels.append(len(spacings) * [3])
2841
return kernels, strides

0 commit comments

Comments
 (0)