Commit c73d0d1

Merge pull request #302 from NVIDIA/topk
Adding support for aten::topk
2 parents b452c6e + 61661ff
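
For context, not part of the diff itself: this change registers a converter that maps aten::topk onto a TensorRT TopK layer, where the op's `largest` flag selects between TopKOperation::kMAX and kMIN and the `sorted` flag is ignored (see the converter's debug note below). A minimal standalone sketch of that layer call, assuming a TensorRT development environment; the helper name and signature are illustrative only:

#include "NvInfer.h"
#include <cstdint>

// Hypothetical helper: add a TopK layer that keeps the k largest (or smallest)
// elements along the dimensions selected by axis_mask.
// Layer output 0 is the values tensor, output 1 is the indices tensor.
nvinfer1::ITopKLayer* add_topk(
    nvinfer1::INetworkDefinition& net,
    nvinfer1::ITensor& input,
    int32_t k,
    uint32_t axis_mask,
    bool largest) {
  auto op = largest ? nvinfer1::TopKOperation::kMAX : nvinfer1::TopKOperation::kMIN;
  return net.addTopK(input, op, k, axis_mask);
}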

File tree

4 files changed, +99 -1 lines changed


core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ cc_library(
         "impl/stack.cpp",
         "impl/lstm_cell.cpp",
         "impl/unsqueeze.cpp",
+        "impl/topk.cpp",
     ],
     deps = [
         "@tensorrt//:nvinfer",
core/conversion/converters/impl/topk.cpp (new file)

Lines changed: 62 additions & 0 deletions

#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

auto topk_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto self = args[0].ITensorOrFreeze(ctx);
       auto k = args[1].unwrapToInt();
       auto dim = args[2].unwrapToInt();
       auto largest = args[3].unwrapToBool();
       LOG_DEBUG(
           "Note: sorted argument is not used in TensorRT for aten::topk, results will depend on the value of largest");
       // auto sorted = args[4].unwrapToBool(); // Currently unused

       auto selfDim = util::toVec(self->getDimensions());

       // reduceAxes: the reduction dimensions. The bit in position i of bitmask reduceAxes corresponds to explicit
       // dimension i of the result. E.g., the least significant bit corresponds to the first explicit dimension and the
       // next to least significant bit corresponds to the second explicit dimension.

       if (dim < 0) {
         dim = selfDim.size() + dim;
       }

       uint32_t shiftDim = 1 << dim;

       LOG_DEBUG("Output topk reduce dim: " << dim);

       auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);

       auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim);

       TRTORCH_CHECK(new_layer, "Unable to create topk layer from node: " << *n);

       auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
       auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));

       LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
       LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
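
The converter encodes the reduction dimension for addTopK as a bitmask (bit i set means reduce over explicit dimension i), after normalizing a negative dim against the tensor rank. A standalone sketch of that encoding; the helper name is hypothetical and not part of the diff:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the converter's logic: normalize a possibly
// negative dim against the tensor rank, then set the corresponding bit.
uint32_t topk_reduce_axes(int64_t dim, size_t rank) {
  if (dim < 0) {
    dim = static_cast<int64_t>(rank) + dim;  // e.g. dim = -1 with rank 3 -> 2
  }
  return 1u << dim;  // bit i selects explicit dimension i
}

int main() {
  assert(topk_reduce_axes(-1, 3) == 0b100);  // last dim of a 3-D tensor -> mask 4
  assert(topk_reduce_axes(0, 3) == 0b001);   // first dim -> mask 1
  return 0;
}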

tests/core/conversion/converters/BUILD

Lines changed: 6 additions & 1 deletion

@@ -71,6 +71,10 @@ converter_test(
     name = "test_stack"
 )
 
+converter_test(
+    name = "test_topk"
+)
+
 converter_test(
     name = "test_lstm_cell"
 )
@@ -103,6 +107,7 @@ test_suite(
        ":test_stack",
        ":test_lstm_cell",
        ":test_unsqueeze",
-       ":test_squeeze"
+       ":test_squeeze",
+       ":test_topk",
    ]
 )
tests/core/conversion/converters/test_topk.cpp (new file)

Lines changed: 30 additions & 0 deletions

#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenTopKConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : int = prim::Constant[value=20]()
        %2 : int = prim::Constant[value=-1]()
        %3 : bool = prim::Constant[value=1]()
        %4 : bool = prim::Constant[value=1]()
        %5 : Tensor, %6 : Tensor = aten::topk(%0, %1, %2, %3, %4)
        return (%5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in = at::rand({10, 10, 100}, {at::kCUDA});

  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});

  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}
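
For reference, the IR graph in this test is just a scripted topk call; in eager C++ it corresponds roughly to the following sketch (assuming a CUDA-enabled libtorch build), which is what the JIT baseline produced by RunGraph computes before being compared against the TensorRT engine results:

#include <torch/torch.h>
#include <tuple>

int main() {
  auto in = torch::rand({10, 10, 100}, torch::kCUDA);
  // Same constants as the prim::Constant nodes above: k = 20, dim = -1, largest = true, sorted = true.
  torch::Tensor values, indices;
  std::tie(values, indices) = torch::topk(in, /*k=*/20, /*dim=*/-1, /*largest=*/true, /*sorted=*/true);
  // Both outputs have shape [10, 10, 20]: the top-20 values along the last dimension and their indices.
  return 0;
}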

0 commit comments
