File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
core/conversion/converters/impl Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -16,9 +16,13 @@ bool GlobalPoolingConverter(
16
16
nvinfer1::PoolingType pool_type) {
17
17
auto in = args[0 ].ITensorOrFreeze (ctx);
18
18
nvinfer1::Dims dims = in->getDimensions ();
19
- auto out_size = util::toDims (args[1 ].unwrapToIntList ());
20
- // Generate a bitmask of all 1s except the last 2 bits (N and C axes)
21
- uint32_t reduceAxes = ((1 << dims.nbDims ) - 1 ) ^ ((1 << (dims.nbDims - out_size.nbDims )) - 1 );
19
+ // Generate a bitmask of all 1s except the last 2 bits (N and C axes) when dims.nbDims >= 2
20
+ uint32_t reduceAxes = ((1 << dims.nbDims ) - 1 ) & ~0b11 ;
21
+ // Generate a bitmask of all 1s except the last 1 bits (N axes) when dims.nbDims == 2. `aten::adaptive_avg_pool1d`'s
22
+ // input can be (N, C, L) or (C, L).
23
+ if (dims.nbDims == 2 ) {
24
+ reduceAxes = ((1 << dims.nbDims ) - 1 ) & ~0b1 ;
25
+ }
22
26
auto * new_layer = ctx->net ->addReduce (
23
27
*in,
24
28
pool_type == nvinfer1::PoolingType::kMAX ? nvinfer1::ReduceOperation::kMAX : nvinfer1::ReduceOperation::kAVG ,
You can’t perform that action at this time.
0 commit comments