Skip to content

Commit c8ef7fe

Browse files
kate-sann5100monai-botpre-commit-ci[bot]wyli
authored
5542 amend extract levels in localnet (#5543)
Fixes #5542 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: kate-sann5100 <[email protected]> Signed-off-by: monai-bot <[email protected]> Signed-off-by: Wenqi Li <[email protected]> Co-authored-by: monai-bot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <[email protected]>
1 parent e4b99e1 commit c8ef7fe

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

monai/networks/nets/regunet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):
342342
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
343343

344344
def forward(self, x: torch.Tensor) -> torch.Tensor:
345-
output_size = (size * 2 for size in x.shape[2:])
345+
output_size = [size * 2 for size in x.shape[2:]]
346346
deconved = self.deconv(x)
347347
resized = F.interpolate(x, output_size)
348348
resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
@@ -372,6 +372,7 @@ def __init__(
372372
out_activation: Optional[str] = None,
373373
out_channels: int = 3,
374374
pooling: bool = True,
375+
use_addictive_sampling: bool = True,
375376
concat_skip: bool = False,
376377
):
377378
"""
@@ -384,12 +385,15 @@ def __init__(
384385
out_channels: number of channels for the output
385386
extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
386387
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
388+
use_addictive_sampling: whether use additive up-sampling layer for decoding.
387389
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
388390
"""
391+
self.use_additive_upsampling = use_addictive_sampling
389392
super().__init__(
390393
spatial_dims=spatial_dims,
391394
in_channels=in_channels,
392395
num_channel_initial=num_channel_initial,
396+
extract_levels=extract_levels,
393397
depth=max(extract_levels),
394398
out_kernel_initializer=out_kernel_initializer,
395399
out_activation=out_activation,
@@ -406,7 +410,7 @@ def build_bottom_block(self, in_channels: int, out_channels: int):
406410
)
407411

408412
def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
409-
if self._use_additive_upsampling:
413+
if self.use_additive_upsampling:
410414
return AdditiveUpSampleBlock(
411415
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels
412416
)

tests/test_localnet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def test_shape(self, input_param, input_shape, expected_shape):
6565
result = net(torch.randn(input_shape).to(device))
6666
self.assertEqual(result.shape, expected_shape)
6767

68+
@parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D)
69+
def test_extract_levels(self, input_param, input_shape, expected_shape):
70+
net = LocalNet(**input_param).to(device)
71+
self.assertEqual(len(net.decode_deconvs), len(input_param["extract_levels"]) - 1)
72+
self.assertEqual(len(net.decode_convs), len(input_param["extract_levels"]) - 1)
73+
6874
def test_script(self):
6975
input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0]
7076
net = LocalNet(**input_param)

0 commit comments

Comments
 (0)