Skip to content

652 Enhance get_kernels_strides of DynUNet pipeline #653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions modules/dynunet_pipeline/create_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@

import torch
from monai.networks.nets import DynUNet

from task_params import deep_supr_num, patch_size, spacing


def get_kernels_strides(task_id):
"""
This function is only used for decathlon datasets with the provided patch sizes.
When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
be divisible by the product of all strides in the corresponding dimension.
In addition, the minimal spatial size should have at least one dimension that has twice the size of
the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.

"""
sizes, spacings = patch_size[task_id], spacing[task_id]
input_size = sizes
strides, kernels = [], []

while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [
Expand All @@ -19,10 +26,16 @@ def get_kernels_strides(task_id):
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
for idx, (i, j) in enumerate(zip(sizes, stride)):
if i % j != 0:
raise ValueError(
f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
)
sizes = [i / j for i, j in zip(sizes, stride)]
spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel)
strides.append(stride)

strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
return kernels, strides
Expand Down