Skip to content

Commit 0ac503e

Browse files
authored
Merge pull request #792 from guoruoqian/fix_pooling
Feat: support aten::adaptive_max_pool1d, aten::adaptive_avg_pool3d and aten::adaptive_max_pool3d operators and fix issue #791
2 parents 726b031 + 143fc3b commit 0ac503e

File tree

3 files changed

+278
-19
lines changed

3 files changed

+278
-19
lines changed

core/conversion/converters/impl/pooling.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@ bool GlobalPoolingConverter(
1616
nvinfer1::PoolingType pool_type) {
1717
auto in = args[0].ITensorOrFreeze(ctx);
1818
nvinfer1::Dims dims = in->getDimensions();
19-
// Generate a bitmask of all 1s except the last 2 bits (N and C axes)
19+
// Generate a bitmask of all 1s except the last 2 bits (N and C axes) when dims.nbDims > 2
2020
uint32_t reduceAxes = ((1 << dims.nbDims) - 1) & ~0b11;
21+
// Generate a bitmask of all 1s except the last bit (N axis) when dims.nbDims == 2. `aten::adaptive_avg_pool1d`'s
22+
// input can be (N, C, L) or (C, L).
23+
if (dims.nbDims == 2) {
24+
reduceAxes = ((1 << dims.nbDims) - 1) & ~0b1;
25+
}
2126
auto* new_layer = ctx->net->addReduce(
2227
*in,
2328
pool_type == nvinfer1::PoolingType::kMAX ? nvinfer1::ReduceOperation::kMAX : nvinfer1::ReduceOperation::kAVG,
@@ -36,7 +41,8 @@ bool AdaptivePoolingConverter(
3641
ConversionCtx* ctx,
3742
const torch::jit::Node* n,
3843
args& args,
39-
nvinfer1::PoolingType pool_type) {
44+
nvinfer1::PoolingType pool_type,
45+
const std::string& mode) {
4046
auto in = args[0].ITensorOrFreeze(ctx);
4147
auto out_size = util::toDims(args[1].unwrapToIntList());
4248

@@ -47,15 +53,7 @@ bool AdaptivePoolingConverter(
4753
}
4854

4955
auto orig_dims = in->getDimensions();
50-
bool expandDims = (orig_dims.nbDims < 4);
51-
TORCHTRT_CHECK(orig_dims.nbDims > 2, "Unable to create pooling layer from node: " << *n);
52-
if (expandDims) {
53-
in = addPadding(ctx, n, in, 4, false, false);
54-
}
55-
56-
if (out_size.nbDims == 1) {
57-
out_size = util::unsqueezeDims(out_size, 0, 1);
58-
}
56+
TORCHTRT_CHECK(orig_dims.nbDims > 1, "Unable to create pooling layer from node: " << *n);
5957

6058
auto in_shape = util::toVec(in->getDimensions());
6159
nvinfer1::ILayer* new_layer = nullptr;
@@ -89,10 +87,6 @@ bool AdaptivePoolingConverter(
8987
int32_t use_scales_casted = 0;
9088
f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
9189

92-
std::string mode = "adaptive_avg_pool2d";
93-
if (pool_type == nvinfer1::PoolingType::kMAX) {
94-
mode = "adaptive_max_pool2d";
95-
}
9690
f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
9791

9892
fc.nbFields = f.size();
@@ -109,7 +103,7 @@ bool AdaptivePoolingConverter(
109103
TORCHTRT_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
110104

111105
new_layer->setName(util::node_info(n).c_str());
112-
auto layer_output = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims, false, false);
106+
auto layer_output = new_layer->getOutput(0);
113107

114108
ctx->AssociateValueAndTensor(n->outputs()[0], layer_output);
115109
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
@@ -237,15 +231,30 @@ auto pooling_registrations TORCHTRT_UNUSED =
237231
}})
238232
.pattern({"aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)",
239233
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
240-
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE);
234+
return AdaptivePoolingConverter(
235+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool1d");
236+
}})
237+
.pattern({"aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)",
238+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
239+
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool1d");
241240
}})
242241
.pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)",
243242
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
244-
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE);
243+
return AdaptivePoolingConverter(
244+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool2d");
245245
}})
246246
.pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)",
247247
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
248-
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX);
248+
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool2d");
249+
}})
250+
.pattern({"aten::adaptive_avg_pool3d(Tensor self, int[3] output_size) -> (Tensor)",
251+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
252+
return AdaptivePoolingConverter(
253+
ctx, n, args, nvinfer1::PoolingType::kAVERAGE, "adaptive_avg_pool3d");
254+
}})
255+
.pattern({"aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)",
256+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
257+
return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX, "adaptive_max_pool3d");
249258
}});
250259
} // namespace
251260
} // namespace impl

core/plugins/impl/interpolate_plugin.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,18 @@ int InterpolatePlugin::enqueue(
289289
out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
290290
} else if (mode_ == "trilinear") {
291291
out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
292+
} else if (mode_ == "adaptive_avg_pool1d") {
293+
out = at::adaptive_avg_pool1d(input, {size_[0]});
294+
} else if (mode_ == "adaptive_max_pool1d") {
295+
out = std::get<0>(at::adaptive_max_pool1d(input, {size_[0]}));
292296
} else if (mode_ == "adaptive_avg_pool2d") {
293297
out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
294298
} else if (mode_ == "adaptive_max_pool2d") {
295299
out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]}));
300+
} else if (mode_ == "adaptive_avg_pool3d") {
301+
out = at::adaptive_avg_pool3d(input, {size_[0], size_[1], size_[2]});
302+
} else if (mode_ == "adaptive_max_pool3d") {
303+
out = std::get<0>(at::adaptive_max_pool3d(input, {size_[0], size_[1], size_[2]}));
296304
}
297305
}
298306

tests/core/conversion/converters/test_pooling.cpp

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,32 @@ TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectly) {
436436
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
437437
}
438438

439+
TEST(Converters, ATenAdaptiveAvgPool2DGlobalPoolingConvertsCorrectly) {
440+
const auto graph = R"IR(
441+
graph(%0 : Tensor):
442+
%2 : int = prim::Constant[value=1]()
443+
%3 : int = prim::Constant[value=1]()
444+
%6 : int[] = prim::ListConstruct(%2, %3)
445+
%10 : Tensor = aten::adaptive_avg_pool2d(%0, %6)
446+
return (%10))IR";
447+
448+
auto g = std::make_shared<torch::jit::Graph>();
449+
torch::jit::parseIR(graph, g.get());
450+
451+
// PyTorch adaptive_avg_pool2d needs a 4D input or a 3D input
452+
auto in = at::randint(-5, 5, {64, 16, 32, 32}, at::kCUDA);
453+
454+
auto jit_in = at::clone(in);
455+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
456+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
457+
458+
auto trt_in = at::clone(in);
459+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
460+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
461+
462+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
463+
}
464+
439465
TEST(Converters, ATenAdaptiveAvgPool2DConvertsCorrectlyWithDynamicInput) {
440466
const auto graph = R"IR(
441467
graph(%0 : Tensor):
@@ -488,6 +514,110 @@ TEST(Converters, ATenAdaptiveAvgPool1DConvertsCorrectly) {
488514
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 1.0));
489515
}
490516

517+
TEST(Converters, ATenAdaptiveAvgPool1DGlobalPoolingConvertsCorrectly) {
518+
const auto graph =
519+
R"IR(
520+
graph(%0 : Tensor):
521+
%2 : int = prim::Constant[value=1]()
522+
%6 : int[] = prim::ListConstruct(%2)
523+
%10 : Tensor = aten::adaptive_avg_pool1d(%0, %6)
524+
return (%10))IR";
525+
526+
auto g = std::make_shared<torch::jit::Graph>();
527+
torch::jit::parseIR(graph, g.get());
528+
529+
// PyTorch adaptive_avg_pool1d needs a 3D input or a 2D input
530+
auto in = at::randint(-5, 5, {3, 16}, at::kCUDA);
531+
532+
auto jit_in = at::clone(in);
533+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
534+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
535+
536+
auto trt_in = at::clone(in);
537+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
538+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
539+
540+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
541+
}
542+
543+
TEST(Converters, ATenAdaptiveAvgPool1DUsingPluginConvertsCorrectly) {
544+
const auto graph =
545+
R"IR(
546+
graph(%0 : Tensor):
547+
%2 : int = prim::Constant[value=3]()
548+
%6 : int[] = prim::ListConstruct(%2)
549+
%10 : Tensor = aten::adaptive_avg_pool1d(%0, %6)
550+
return (%10))IR";
551+
552+
auto g = std::make_shared<torch::jit::Graph>();
553+
torch::jit::parseIR(graph, g.get());
554+
555+
// PyTorch adaptive_avg_pool1d needs a 3D input or a 2D input
556+
auto in = at::randint(-5, 5, {1, 3, 16}, at::kCUDA);
557+
558+
auto jit_in = at::clone(in);
559+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
560+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
561+
562+
auto trt_in = at::clone(in);
563+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
564+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
565+
566+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
567+
}
568+
569+
TEST(Converters, ATenAdaptiveMaxPool1DGlobalPoolingConvertsCorrectly) {
570+
const auto graph =
571+
R"IR(
572+
graph(%0 : Tensor):
573+
%2 : int = prim::Constant[value=1]()
574+
%6 : int[] = prim::ListConstruct(%2)
575+
%10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
576+
return (%10, %11))IR";
577+
578+
auto g = std::make_shared<torch::jit::Graph>();
579+
torch::jit::parseIR(graph, g.get());
580+
581+
// PyTorch adaptive_max_pool1d needs a 3D input or a 2D input
582+
auto in = at::randint(-5, 5, {1, 3, 16}, at::kCUDA);
583+
584+
auto jit_in = at::clone(in);
585+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
586+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
587+
588+
auto trt_in = at::clone(in);
589+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
590+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
591+
592+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
593+
}
594+
595+
TEST(Converters, ATenAdaptiveMaxPool1DUsingPluginConvertsCorrectly) {
596+
const auto graph =
597+
R"IR(
598+
graph(%0 : Tensor):
599+
%2 : int = prim::Constant[value=3]()
600+
%6 : int[] = prim::ListConstruct(%2)
601+
%10 : Tensor, %11 : Tensor = aten::adaptive_max_pool1d(%0, %6)
602+
return (%10, %11))IR";
603+
604+
auto g = std::make_shared<torch::jit::Graph>();
605+
torch::jit::parseIR(graph, g.get());
606+
607+
// PyTorch adaptive_max_pool1d needs a 3D input or a 2D input
608+
auto in = at::randint(-5, 5, {1, 3, 16}, at::kCUDA);
609+
610+
auto jit_in = at::clone(in);
611+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
612+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
613+
614+
auto trt_in = at::clone(in);
615+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
616+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
617+
618+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
619+
}
620+
491621
TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectly) {
492622
const auto graph = R"IR(
493623
graph(%0 : Tensor):
@@ -539,3 +669,115 @@ TEST(Converters, ATenAdaptiveMaxPool2DConvertsCorrectlyWithDynamicInput) {
539669

540670
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
541671
}
672+
673+
TEST(Converters, ATenAdaptiveAvgPool3DGlobalPoolingConvertsCorrectly) {
674+
const auto graph =
675+
R"IR(
676+
graph(%0 : Tensor):
677+
%2 : int = prim::Constant[value=1]()
678+
%3 : int = prim::Constant[value=1]()
679+
%4 : int = prim::Constant[value=1]()
680+
%6 : int[] = prim::ListConstruct(%2, %3, %4)
681+
%10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
682+
return (%10))IR";
683+
684+
auto g = std::make_shared<torch::jit::Graph>();
685+
torch::jit::parseIR(graph, g.get());
686+
687+
// PyTorch adaptive_avg_pool3d needs a 5D input or a 4D input
688+
auto in = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);
689+
690+
auto jit_in = at::clone(in);
691+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
692+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
693+
694+
auto trt_in = at::clone(in);
695+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
696+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
697+
698+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
699+
}
700+
701+
TEST(Converters, ATenAdaptiveAvgPool3DUsingPluginConvertsCorrectly) {
702+
const auto graph =
703+
R"IR(
704+
graph(%0 : Tensor):
705+
%2 : int = prim::Constant[value=7]()
706+
%3 : int = prim::Constant[value=6]()
707+
%4 : int = prim::Constant[value=5]()
708+
%6 : int[] = prim::ListConstruct(%2, %3, %4)
709+
%10 : Tensor = aten::adaptive_avg_pool3d(%0, %6)
710+
return (%10))IR";
711+
712+
auto g = std::make_shared<torch::jit::Graph>();
713+
torch::jit::parseIR(graph, g.get());
714+
715+
// PyTorch adaptive_avg_pool3d needs a 5D input or a 4D input
716+
auto in = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);
717+
718+
auto jit_in = at::clone(in);
719+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
720+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
721+
722+
auto trt_in = at::clone(in);
723+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
724+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
725+
726+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
727+
}
728+
729+
TEST(Converters, ATenAdaptiveMaxPool3DGlobalPoolingConvertsCorrectly) {
730+
const auto graph =
731+
R"IR(
732+
graph(%0 : Tensor):
733+
%2 : int = prim::Constant[value=1]()
734+
%3 : int = prim::Constant[value=1]()
735+
%4 : int = prim::Constant[value=1]()
736+
%6 : int[] = prim::ListConstruct(%2, %3, %4)
737+
%10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
738+
return (%10, %11))IR";
739+
740+
auto g = std::make_shared<torch::jit::Graph>();
741+
torch::jit::parseIR(graph, g.get());
742+
743+
// PyTorch adaptive_max_pool3d needs a 5D input or a 4D input
744+
auto in = at::randint(-5, 5, {5, 3, 15, 16}, at::kCUDA);
745+
746+
auto jit_in = at::clone(in);
747+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
748+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
749+
750+
auto trt_in = at::clone(in);
751+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
752+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
753+
754+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
755+
}
756+
757+
TEST(Converters, ATenAdaptiveMaxPool3DUsingPluginConvertsCorrectly) {
758+
const auto graph =
759+
R"IR(
760+
graph(%0 : Tensor):
761+
%2 : int = prim::Constant[value=7]()
762+
%3 : int = prim::Constant[value=8]()
763+
%4 : int = prim::Constant[value=9]()
764+
%6 : int[] = prim::ListConstruct(%2, %3, %4)
765+
%10 : Tensor, %11 : Tensor = aten::adaptive_max_pool3d(%0, %6)
766+
return (%10, %11))IR";
767+
768+
auto g = std::make_shared<torch::jit::Graph>();
769+
torch::jit::parseIR(graph, g.get());
770+
771+
// PyTorch adaptive_max_pool3d needs a 5D input or a 4D input
772+
auto in = at::randint(-5, 5, {4, 5, 3, 15, 16}, at::kCUDA);
773+
774+
auto jit_in = at::clone(in);
775+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
776+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
777+
778+
auto trt_in = at::clone(in);
779+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
780+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
781+
782+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
783+
}

0 commit comments

Comments
 (0)