15
15
import numpy as np
16
16
17
17
from monai .deploy .utils .importutil import optional_import
18
+ from monai .utils import StrEnum # Will use the built-in StrEnum when SDK requires Python 3.11.
18
19
19
20
MONAI_UTILS = "monai.utils"
20
21
torch , _ = optional_import ("torch" , "1.5" )
28
29
ImageReader = object # for 'class InMemImageReader(ImageReader):' to work
29
30
decollate_batch , _ = optional_import ("monai.data" , name = "decollate_batch" )
30
31
sliding_window_inference , _ = optional_import ("monai.inferers" , name = "sliding_window_inference" )
32
+ simple_inference , _ = optional_import ("monai.inferers" , name = "SimpleInferer" )
31
33
ensure_tuple , _ = optional_import (MONAI_UTILS , name = "ensure_tuple" )
32
34
MetaKeys , _ = optional_import (MONAI_UTILS , name = "MetaKeys" )
33
35
SpaceKeys , _ = optional_import (MONAI_UTILS , name = "SpaceKeys" )
40
42
41
43
from .inference_operator import InferenceOperator
42
44
43
- __all__ = ["MonaiSegInferenceOperator" , "InMemImageReader" ]
45
+ __all__ = ["MonaiSegInferenceOperator" , "InfererType" , "InMemImageReader" ]
46
+
47
+
48
+ class InfererType (StrEnum ):
49
+ """Represents the supported types of the inferer, e.g. Simple and Sliding Window."""
50
+
51
+ SIMPLE = "simple"
52
+ SLIDING_WINDOW = "sliding_window"
44
53
45
54
46
55
@md .input ("image" , Image , IOType .IN_MEMORY )
@@ -61,22 +70,30 @@ class MonaiSegInferenceOperator(InferenceOperator):
61
70
62
71
def __init__ (
63
72
self ,
64
- roi_size : Union [Sequence [int ], int ],
73
+ roi_size : Optional [ Union [Sequence [int ], int ] ],
65
74
pre_transforms : Compose ,
66
75
post_transforms : Compose ,
67
76
model_name : Optional [str ] = "" ,
68
- overlap : float = 0.5 ,
77
+ overlap : float = 0.25 ,
78
+ sw_batch_size : int = 4 ,
79
+ inferer : Union [InfererType , str ] = InfererType .SLIDING_WINDOW ,
69
80
* args ,
70
81
** kwargs ,
71
82
):
72
83
"""Creates a instance of this class.
73
84
74
85
Args:
75
- roi_size (Union[Sequence[int], int]): The tensor size used in inference.
86
+ roi_size (Union[Sequence[int], int]): The window size to execute "SLIDING_WINDOW" evaluation.
87
+ An optional input only to be passed for "SLIDING_WINDOW".
88
+ If using a "SIMPLE" Inferer, this input is ignored.
76
89
pre_transforms (Compose): MONAI Compose object used for pre-transforms.
77
90
post_transforms (Compose): MONAI Compose object used for post-transforms.
78
91
model_name (str, optional): Name of the model. Default to "" for single model app.
79
- overlap (float): The overlap used in sliding window inference.
92
+ overlap (float): The amount of overlap between scans along each spatial dimension. Defaults to 0.25.
93
+ Applicable for "SLIDING_WINDOW" only.
94
+ sw_batch_size(int): The batch size to run window slices. Defaults to 4.
95
+ Applicable for "SLIDING_WINDOW" only.
96
+ inferer (InfererType): The type of inferer to use, "SIMPLE" or "SLIDING_WINDOW". Defaults to "SLIDING_WINDOW".
80
97
"""
81
98
82
99
super ().__init__ ()
@@ -90,7 +107,9 @@ def __init__(
90
107
self ._pre_transform = pre_transforms
91
108
self ._post_transforms = post_transforms
92
109
self ._model_name = model_name .strip () if isinstance (model_name , str ) else ""
93
- self .overlap = overlap
110
+ self ._overlap = overlap
111
+ self ._sw_batch_size = sw_batch_size
112
+ self ._inferer = inferer
94
113
95
114
@property
96
115
def roi_size (self ):
@@ -134,6 +153,28 @@ def overlap(self, val: float):
134
153
raise ValueError ("Overlap must be between 0 and 1." )
135
154
self ._overlap = val
136
155
156
+ @property
157
+ def sw_batch_size (self ):
158
+ """The batch size to run window slices"""
159
+ return self ._sw_batch_size
160
+
161
+ @sw_batch_size .setter
162
+ def sw_batch_size (self , val : int ):
163
+ if not isinstance (val , int ) or val < 0 :
164
+ raise ValueError ("sw_batch_size must be a positive integer." )
165
+ self ._sw_batch_size = val
166
+
167
+ @property
168
+ def inferer (self ) -> Union [InfererType , str ]:
169
+ """The type of inferer to use"""
170
+ return self ._inferer
171
+
172
+ @inferer .setter
173
+ def inferer (self , val : InfererType ):
174
+ if not isinstance (val , InfererType ):
175
+ raise ValueError (f"Value must be of the correct type { InfererType } ." )
176
+ self ._inferer = val
177
+
137
178
def _convert_dicom_metadata_datatype (self , metadata : Dict ):
138
179
"""Converts metadata in pydicom types to the corresponding native types.
139
180
@@ -218,14 +259,22 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
218
259
with torch .no_grad ():
219
260
for d in dataloader :
220
261
images = d [self ._input_dataset_key ].to (device )
221
- sw_batch_size = 4
222
- d [self ._pred_dataset_key ] = sliding_window_inference (
223
- inputs = images ,
224
- roi_size = self ._roi_size ,
225
- sw_batch_size = sw_batch_size ,
226
- overlap = self .overlap ,
227
- predictor = model ,
228
- )
262
+ if self ._inferer == InfererType .SLIDING_WINDOW :
263
+ # Uses the util function to drive the sliding_window inferer
264
+ d [self ._pred_dataset_key ] = sliding_window_inference (
265
+ inputs = images ,
266
+ roi_size = self ._roi_size ,
267
+ sw_batch_size = self ._sw_batch_size ,
268
+ overlap = self ._overlap ,
269
+ predictor = model ,
270
+ )
271
+ elif self ._inferer == InfererType .SIMPLE :
272
+ # Instantiates the SimpleInferer and directly uses its __call__ function
273
+ d [self ._pred_dataset_key ] = simple_inference ()(inputs = images , network = model )
274
+ else :
275
+ raise ValueError (
276
+ f"Unknown inferer: { self ._inferer } . Available options are sliding_window or simple."
277
+ )
229
278
d = [post_transforms (i ) for i in decollate_batch (d )]
230
279
out_ndarray = d [0 ][self ._pred_dataset_key ].cpu ().numpy ()
231
280
# Need to squeeze out the channel dim fist
@@ -241,8 +290,10 @@ def compute(self, op_input: InputContext, op_output: OutputContext, context: Exe
241
290
out_ndarray = out_ndarray .T .astype (np .uint8 )
242
291
print (f"Output Seg image numpy array shaped: { out_ndarray .shape } " )
243
292
print (f"Output Seg image pixel max value: { np .amax (out_ndarray )} " )
293
+ print (f"Output Seg image pixel min value: { np .amin (out_ndarray )} " )
244
294
out_image = Image (out_ndarray , input_img_metadata )
245
295
op_output .set (out_image , "seg_image" )
296
+
246
297
finally :
247
298
# Reset state on completing this method execution.
248
299
with self ._lock :
0 commit comments