@@ -68,6 +68,8 @@ class UNet(nn.Module):
68
68
bias: whether to have a bias term in convolution blocks. Defaults to True.
69
69
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
70
70
if a conv layer is directly followed by a batch norm layer, bias should be False.
71
+ adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
72
+ Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`.
71
73
72
74
Examples::
73
75
@@ -122,6 +124,7 @@ def __init__(
122
124
norm : Union [Tuple , str ] = Norm .INSTANCE ,
123
125
dropout : float = 0.0 ,
124
126
bias : bool = True ,
127
+ adn_ordering : str = "NDA" ,
125
128
dimensions : Optional [int ] = None ,
126
129
) -> None :
127
130
@@ -155,6 +158,7 @@ def __init__(
155
158
self .norm = norm
156
159
self .dropout = dropout
157
160
self .bias = bias
161
+ self .adn_ordering = adn_ordering
158
162
159
163
def _create_block (
160
164
inc : int , outc : int , channels : Sequence [int ], strides : Sequence [int ], is_top : bool
@@ -229,6 +233,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_
229
233
norm = self .norm ,
230
234
dropout = self .dropout ,
231
235
bias = self .bias ,
236
+ adn_ordering = self .adn_ordering ,
232
237
)
233
238
return mod
234
239
mod = Convolution (
@@ -241,6 +246,7 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_
241
246
norm = self .norm ,
242
247
dropout = self .dropout ,
243
248
bias = self .bias ,
249
+ adn_ordering = self .adn_ordering ,
244
250
)
245
251
return mod
246
252
@@ -279,6 +285,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_to
279
285
bias = self .bias ,
280
286
conv_only = is_top and self .num_res_units == 0 ,
281
287
is_transposed = True ,
288
+ adn_ordering = self .adn_ordering ,
282
289
)
283
290
284
291
if self .num_res_units > 0 :
@@ -294,6 +301,7 @@ def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_to
294
301
dropout = self .dropout ,
295
302
bias = self .bias ,
296
303
last_conv_only = is_top ,
304
+ adn_ordering = self .adn_ordering ,
297
305
)
298
306
conv = nn .Sequential (conv , ru )
299
307
0 commit comments