@@ -134,6 +134,7 @@ def __init__(
134
134
super (SelectAdaptivePool2d , self ).__init__ ()
135
135
assert input_fmt in ('NCHW' , 'NHWC' )
136
136
self .pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
137
+ pool_type = pool_type .lower ()
137
138
if not pool_type :
138
139
self .pool = nn .Identity () # pass through
139
140
self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
@@ -145,8 +146,10 @@ def __init__(
145
146
self .pool = FastAdaptiveAvgMaxPool (flatten , input_fmt = input_fmt )
146
147
elif pool_type .endswith ('max' ):
147
148
self .pool = FastAdaptiveMaxPool (flatten , input_fmt = input_fmt )
148
- else :
149
+ elif pool_type == 'fast' or pool_type . endswith ( 'avg' ) :
149
150
self .pool = FastAdaptiveAvgPool (flatten , input_fmt = input_fmt )
151
+ else :
152
+ assert False , 'Invalid pool type: %s' % pool_type
150
153
self .flatten = nn .Identity ()
151
154
else :
152
155
assert input_fmt == 'NCHW'
@@ -156,8 +159,10 @@ def __init__(
156
159
self .pool = AdaptiveCatAvgMaxPool2d (output_size )
157
160
elif pool_type == 'max' :
158
161
self .pool = nn .AdaptiveMaxPool2d (output_size )
159
- else :
162
+ elif pool_type == 'avg' :
160
163
self .pool = nn .AdaptiveAvgPool2d (output_size )
164
+ else :
165
+ assert False , 'Invalid pool type: %s' % pool_type
161
166
self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
162
167
163
168
def is_identity (self ):
0 commit comments