Skip to content

Commit 15aa098

Browse files
committed
chore: update reduceAxes variable in GlobalPoolingConverter
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent e554dbd commit 15aa098

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

core/conversion/converters/impl/pooling.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@ bool GlobalPoolingConverter(
1616
nvinfer1::PoolingType pool_type) {
1717
auto in = args[0].ITensorOrFreeze(ctx);
1818
nvinfer1::Dims dims = in->getDimensions();
19-
auto out_size = util::toDims(args[1].unwrapToIntList());
20-
// Generate a bitmask of all 1s except the last 2 bits (N and C axes)
21-
uint32_t reduceAxes = ((1 << dims.nbDims) - 1) ^ ((1 << (dims.nbDims - out_size.nbDims)) - 1);
19+
// Generate a bitmask of all 1s except the last 2 bits (N and C axes) when dims.nbDims >= 2
20+
uint32_t reduceAxes = ((1 << dims.nbDims) - 1) & ~0b11;
21+
// Generate a bitmask of all 1s except the last 1 bits (N axes) when dims.nbDims == 2. `aten::adaptive_avg_pool1d`'s
22+
// input can be (N, C, L) or (C, L).
23+
if (dims.nbDims == 2) {
24+
reduceAxes = ((1 << dims.nbDims) - 1) & ~0b1;
25+
}
2226
auto* new_layer = ctx->net->addReduce(
2327
*in,
2428
pool_type == nvinfer1::PoolingType::kMAX ? nvinfer1::ReduceOperation::kMAX : nvinfer1::ReduceOperation::kAVG,

0 commit comments

Comments
 (0)