2
2
3
3
import torch
4
4
from monai .networks .nets import DynUNet
5
-
6
5
from task_params import deep_supr_num , patch_size , spacing
7
6
8
7
9
8
def get_kernels_strides (task_id ):
9
+ """
10
+ This function is only used for decathlon datasets with the provided patch sizes.
11
+ When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
12
+ be divisible by the product of all strides in the corresponding dimension.
13
+ In addition, the minimal spatial size should have at least one dimension that has twice the size of
14
+ the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
15
+
16
+ """
10
17
sizes , spacings = patch_size [task_id ], spacing [task_id ]
18
+ input_size = sizes
11
19
strides , kernels = [], []
12
-
13
20
while True :
14
21
spacing_ratio = [sp / min (spacings ) for sp in spacings ]
15
22
stride = [
@@ -19,10 +26,16 @@ def get_kernels_strides(task_id):
19
26
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio ]
20
27
if all (s == 1 for s in stride ):
21
28
break
29
+ for idx , (i , j ) in enumerate (zip (sizes , stride )):
30
+ if i % j != 0 :
31
+ raise ValueError (
32
+ f"Patch size is not supported, please try to modify the size { input_size [idx ]} in the spatial dimension { idx } ."
33
+ )
22
34
sizes = [i / j for i , j in zip (sizes , stride )]
23
35
spacings = [i * j for i , j in zip (spacings , stride )]
24
36
kernels .append (kernel )
25
37
strides .append (stride )
38
+
26
39
strides .insert (0 , len (spacings ) * [1 ])
27
40
kernels .append (len (spacings ) * [3 ])
28
41
return kernels , strides
0 commit comments