Skip to content

Commit 7e404e6

Browse files
authored
Merge pull request #980 from NVIDIA/pr934
Adding test case for PR #934
2 parents c952291 + 535d1a5 commit 7e404e6

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cc_library(
7979
"impl/squeeze.cpp",
8080
"impl/stack.cpp",
8181
"impl/topk.cpp",
82+
"impl/max.cpp",
8283
"impl/unary.cpp",
8384
"impl/unsqueeze.cpp",
8485
],
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace torch_tensorrt {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
17+
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
18+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19+
auto self = args[0].ITensorOrFreeze(ctx);
20+
auto dim = args[1].unwrapToInt();
21+
auto selfDim = util::toVec(self->getDimensions());
22+
if (dim < 0) {
23+
dim = selfDim.size() + dim;
24+
}
25+
uint32_t shiftDim = 1 << dim;
26+
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
27+
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
28+
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
29+
30+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
31+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
32+
33+
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
34+
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
35+
36+
return true;
37+
}});
38+
} // namespace
39+
} // namespace impl
40+
} // namespace converters
41+
} // namespace conversion
42+
} // namespace core
43+
} // namespace torch_tensorrt

tests/core/conversion/converters/test_topk.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,28 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
3030
ASSERT_TRUE(
3131
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
3232
}
33+
34+
TEST(Converters, ATenMaxDimConvertsCorrectly) {
  // IR under test: aten::max(self, dim=0, keepdim=False) -> (values, indices).
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=0]()
        %3 : bool = prim::Constant[value=0]()
        %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
        return (%4, %5))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_graph.get());

  auto input = at::rand({2, 3, 5, 5}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_graph, static_params, {input});

  // Same graph lowered through the TensorRT converter path.
  static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, static_params, {input});

  // Compare values (output 0) and indices (output 1); reshape_as tolerates
  // the extent-1 dim TopK keeps on the reduced axis.
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}

0 commit comments

Comments
 (0)