Skip to content

Commit 1437309

Browse files
vikashgVikash GuptaMMelQin
authored
Added option for SimpleInferer (#428)
* Added option for SimpleInferer on monai_seg_inference_operator.py Signed-off-by: Vikash Gupta <[email protected]> Signed-off-by: M Q <[email protected]> * Corrected formatting Signed-off-by: M Q <[email protected]> * Changed formatting Signed-off-by: Vikash Gupta <[email protected]> Signed-off-by: M Q <[email protected]> * Changed str to StrEnum on inferer options Signed-off-by: Vikash Gupta <[email protected]> Signed-off-by: M Q <[email protected]> * Fixing signoff Signed-off-by: Vikash Gupta <[email protected]> Signed-off-by: M Q <[email protected]> * Define the InfererType StrEnum and updated arguments Signed-off-by: M Q <[email protected]> * Correct a typo in print statement Signed-off-by: M Q <[email protected]> * Quiet pytype complaints Signed-off-by: M Q <[email protected]> --------- Signed-off-by: Vikash Gupta <[email protected]> Signed-off-by: M Q <[email protected]> Co-authored-by: Vikash Gupta <[email protected]> Co-authored-by: M Q <[email protected]>
1 parent 6b28398 commit 1437309

File tree

2 files changed

+68
-15
lines changed

2 files changed

+68
-15
lines changed

examples/apps/ai_livertumor_seg_app/livertumor_seg_operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import monai.deploy.core as md
1717
from monai.deploy.core import DataPath, ExecutionContext, Image, InputContext, IOType, Operator, OutputContext
18-
from monai.deploy.operators.monai_seg_inference_operator import InMemImageReader, MonaiSegInferenceOperator
18+
from monai.deploy.operators.monai_seg_inference_operator import InfererType, InMemImageReader, MonaiSegInferenceOperator
1919
from monai.transforms import (
2020
Activationsd,
2121
AsDiscreted,
@@ -92,6 +92,8 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
9292
post_transforms,
9393
overlap=0.6,
9494
model_name="",
95+
inferer=InfererType.SLIDING_WINDOW,
96+
sw_batch_size=4,
9597
)
9698

9799
# Setting the keys used in the dictironary based transforms may change.

monai/deploy/operators/monai_seg_inference_operator.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616

1717
from monai.deploy.utils.importutil import optional_import
18+
from monai.utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
1819

1920
MONAI_UTILS = "monai.utils"
2021
torch, _ = optional_import("torch", "1.5")
@@ -28,6 +29,7 @@
2829
ImageReader = object # for 'class InMemImageReader(ImageReader):' to work
2930
decollate_batch, _ = optional_import("monai.data", name="decollate_batch")
3031
sliding_window_inference, _ = optional_import("monai.inferers", name="sliding_window_inference")
32+
simple_inference, _ = optional_import("monai.inferers", name="SimpleInferer")
3133
ensure_tuple, _ = optional_import(MONAI_UTILS, name="ensure_tuple")
3234
MetaKeys, _ = optional_import(MONAI_UTILS, name="MetaKeys")
3335
SpaceKeys, _ = optional_import(MONAI_UTILS, name="SpaceKeys")
@@ -40,7 +42,14 @@
4042

4143
from .inference_operator import InferenceOperator
4244

43-
__all__ = ["MonaiSegInferenceOperator", "InMemImageReader"]
45+
__all__ = ["MonaiSegInferenceOperator", "InfererType", "InMemImageReader"]
46+
47+
48+
class InfererType(StrEnum):
49+
"""Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
50+
51+
SIMPLE = "simple"
52+
SLIDING_WINDOW = "sliding_window"
4453

4554

4655
@md.input("image", Image, IOType.IN_MEMORY)
@@ -61,22 +70,30 @@ class MonaiSegInferenceOperator(InferenceOperator):
6170

6271
def __init__(
6372
self,
64-
roi_size: Union[Sequence[int], int],
73+
roi_size: Optional[Union[Sequence[int], int]],
6574
pre_transforms: Compose,
6675
post_transforms: Compose,
6776
model_name: Optional[str] = "",
68-
overlap: float = 0.5,
77+
overlap: float = 0.25,
78+
sw_batch_size: int = 4,
79+
inferer: Union[InfererType, str] = InfererType.SLIDING_WINDOW,
6980
*args,
7081
**kwargs,
7182
):
7283
"""Creates a instance of this class.
7384
7485
Args:
75-
roi_size (Union[Sequence[int], int]): The tensor size used in inference.
86+
roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
87+
An optional input only to be passed for "SLIDING_WINDOW".
88+
If using a "SIMPLE" Inferer, this input is ignored.
7689
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
7790
post_transforms (Compose): MONAI Compose object used for post-transforms.
7891
model_name (str, optional): Name of the model. Default to "" for single model app.
79-
overlap (float): The overlap used in sliding window inference.
92+
overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
93+
Applicable for "SLIDING_WINDOW" only.
94+
sw_batch_size(int): The batch size to run window slices. Defaults to 4.
95+
Applicable for "SLIDING_WINDOW" only.
96+
inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
8097
"""
8198

8299
super().__init__()
@@ -90,7 +107,9 @@ def __init__(
90107
self._pre_transform = pre_transforms
91108
self._post_transforms = post_transforms
92109
self._model_name = model_name.strip() if isinstance(model_name, str) else ""
93-
self.overlap = overlap
110+
self._overlap = overlap
111+
self._sw_batch_size = sw_batch_size
112+
self._inferer = inferer
94113

95114
@property
96115
def roi_size(self):
@@ -134,6 +153,28 @@ def overlap(self, val: float):
134153
raise ValueError("Overlap must be between 0 and 1.")
135154
self._overlap = val
136155

156+
@property
157+
def sw_batch_size(self):
158+
"""The batch size to run window slices"""
159+
return self._sw_batch_size
160+
161+
@sw_batch_size.setter
162+
def sw_batch_size(self, val: int):
163+
if not isinstance(val, int) or val < 0:
164+
raise ValueError("sw_batch_size must be a positive integer.")
165+
self._sw_batch_size = val
166+
167+
@property
168+
def inferer(self) -> Union[InfererType, str]:
169+
"""The type of inferer to use"""
170+
return self._inferer
171+
172+
@inferer.setter
173+
def inferer(self, val: InfererType):
174+
if not isinstance(val, InfererType):
175+
raise ValueError(f"Value must be of the correct type {InfererType}.")
176+
self._inferer = val
177+
137178
def _convert_dicom_metadata_datatype(self, metadata: Dict):
138179
"""Converts metadata in pydicom types to the corresponding native types.
139180
@@ -218,14 +259,22 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
218259
with torch.no_grad():
219260
for d in dataloader:
220261
images = d[self._input_dataset_key].to(device)
221-
sw_batch_size = 4
222-
d[self._pred_dataset_key] = sliding_window_inference(
223-
inputs=images,
224-
roi_size=self._roi_size,
225-
sw_batch_size=sw_batch_size,
226-
overlap=self.overlap,
227-
predictor=model,
228-
)
262+
if self._inferer == InfererType.SLIDING_WINDOW:
263+
# Uses the util function to drive the sliding_window inferer
264+
d[self._pred_dataset_key] = sliding_window_inference(
265+
inputs=images,
266+
roi_size=self._roi_size,
267+
sw_batch_size=self._sw_batch_size,
268+
overlap=self._overlap,
269+
predictor=model,
270+
)
271+
elif self._inferer == InfererType.SIMPLE:
272+
# Instantiates the SimpleInferer and directly uses its __call__ function
273+
d[self._pred_dataset_key] = simple_inference()(inputs=images, network=model)
274+
else:
275+
raise ValueError(
276+
f"Unknown inferer: {self._inferer}. Available options are sliding_window or simple."
277+
)
229278
d = [post_transforms(i) for i in decollate_batch(d)]
230279
out_ndarray = d[0][self._pred_dataset_key].cpu().numpy()
231280
# Need to squeeze out the channel dim fist
@@ -241,8 +290,10 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
241290
out_ndarray = out_ndarray.T.astype(np.uint8)
242291
print(f"Output Seg image numpy array shaped: {out_ndarray.shape}")
243292
print(f"Output Seg image pixel max value: {np.amax(out_ndarray)}")
293+
print(f"Output Seg image pixel min value: {np.amin(out_ndarray)}")
244294
out_image = Image(out_ndarray, input_img_metadata)
245295
op_output.set(out_image, "seg_image")
296+
246297
finally:
247298
# Reset state on completing this method execution.
248299
with self._lock:

0 commit comments

Comments
 (0)