Skip to content

Commit 569bcde

Browse files
committed
feat: Add converter files for torch::max
Signed-off-by: hongwei03 <[email protected]>
1 parent 11bcb98 commit 569bcde

File tree

1 file changed

+47
-0
lines changed
  • core/conversion/converters/impl

1 file changed

+47
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace torch_tensorrt {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
17+
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
18+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19+
auto self = args[0].ITensorOrFreeze(ctx);
20+
auto k = 1;
21+
auto dim = args[1].unwrapToInt();
22+
auto largest = true;
23+
auto selfDim = util::toVec(self->getDimensions());
24+
if (dim < 0) {
25+
dim = selfDim.size() + dim;
26+
}
27+
uint32_t shiftDim = 1 << dim;
28+
29+
auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);
30+
31+
auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
32+
TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n);
33+
34+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
35+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
36+
37+
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
38+
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
39+
40+
return true;
41+
}});
42+
} // namespace
43+
} // namespace impl
44+
} // namespace converters
45+
} // namespace conversion
46+
} // namespace core
47+
} // namespace torch_tensorrt

0 commit comments

Comments
 (0)