Skip to content

Commit ffaa791

Browse files
authored
UpSample optional kernel_size for deconv mode (#5221)
### Description Adds (optional) kernel_size parameter to UpSample, used for deconv (convolution transpose up-sampling). This allows to upsample, e.g to upscale to 2x with a kernel_size 3. (currently the default is to upscale to 2x with a kernel size 2) if this parameter is not set, the behavior is the same as before ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <[email protected]>
1 parent 77fd5f4 commit ffaa791

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

monai/networks/blocks/upsample.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
in_channels: Optional[int] = None,
4141
out_channels: Optional[int] = None,
4242
scale_factor: Union[Sequence[float], float] = 2,
43+
kernel_size: Optional[Union[Sequence[float], float]] = None,
4344
size: Optional[Union[Tuple[int], int]] = None,
4445
mode: Union[UpsampleMode, str] = UpsampleMode.DECONV,
4546
pre_conv: Optional[Union[nn.Module, str]] = "default",
@@ -54,6 +55,7 @@ def __init__(
5455
in_channels: number of channels of the input image.
5556
out_channels: number of channels of the output image. Defaults to `in_channels`.
5657
scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
58+
kernel_size: kernel size used during UpsampleMode.DECONV. Defaults to `scale_factor`.
5759
size: spatial size of the output image.
5860
Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
5961
In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,
@@ -83,13 +85,24 @@ def __init__(
8385
if up_mode == UpsampleMode.DECONV:
8486
if not in_channels:
8587
raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
88+
89+
if not kernel_size:
90+
kernel_size_ = scale_factor_
91+
output_padding = padding = 0
92+
else:
93+
kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
94+
padding = tuple((k - 1) // 2 for k in kernel_size_) # type: ignore
95+
output_padding = tuple(s - 1 - (k - 1) % 2 for k, s in zip(kernel_size_, scale_factor_)) # type: ignore
96+
8697
self.add_module(
8798
"deconv",
8899
Conv[Conv.CONVTRANS, spatial_dims](
89100
in_channels=in_channels,
90101
out_channels=out_channels or in_channels,
91-
kernel_size=scale_factor_,
102+
kernel_size=kernel_size_,
92103
stride=scale_factor_,
104+
padding=padding,
105+
output_padding=output_padding,
93106
bias=bias,
94107
),
95108
)

tests/test_milmodel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from monai.networks import eval_mode
1818
from monai.networks.nets import MILModel
1919
from monai.utils.module import optional_import
20-
from tests.utils import test_script_save
20+
from tests.utils import skip_if_downloading_fails, test_script_save
2121

2222
models, _ = optional_import("torchvision.models")
2323

@@ -65,7 +65,8 @@
6565
class TestMilModel(unittest.TestCase):
6666
@parameterized.expand(TEST_CASE_MILMODEL)
6767
def test_shape(self, input_param, input_shape, expected_shape):
68-
net = MILModel(**input_param).to(device)
68+
with skip_if_downloading_fails():
69+
net = MILModel(**input_param).to(device)
6970
with eval_mode(net):
7071
result = net(torch.randn(input_shape, dtype=torch.float).to(device))
7172
self.assertEqual(result.shape, expected_shape)

tests/test_upsample_block.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,34 @@
8888
TEST_CASES_EQ.append(test_case)
8989

9090

91+
TEST_CASES_EQ2 = [] # type: ignore
92+
for s in range(2, 5):
93+
for k in range(1, 7):
94+
expected_shape = (16, 5, 4 * s, 5 * s, 6 * s)
95+
for t in UpsampleMode:
96+
test_case = [
97+
{
98+
"spatial_dims": 3,
99+
"in_channels": 3,
100+
"out_channels": 5,
101+
"mode": t,
102+
"scale_factor": s,
103+
"kernel_size": k,
104+
"align_corners": False,
105+
},
106+
(16, 3, 4, 5, 6),
107+
expected_shape,
108+
]
109+
TEST_CASES_EQ.append(test_case)
110+
111+
91112
class TestUpsample(unittest.TestCase):
92-
@parameterized.expand(TEST_CASES + TEST_CASES_EQ)
113+
@parameterized.expand(TEST_CASES + TEST_CASES_EQ + TEST_CASES_EQ2)
93114
def test_shape(self, input_param, input_shape, expected_shape):
94115
net = UpSample(**input_param)
95116
with eval_mode(net):
96117
result = net(torch.randn(input_shape))
97-
self.assertEqual(result.shape, expected_shape)
118+
self.assertEqual(result.shape, expected_shape, msg=str(input_param))
98119

99120

100121
if __name__ == "__main__":

0 commit comments

Comments
 (0)