|
| 1 | +#include "NvInfer.h" |
| 2 | +#include "core/conversion/converters/converters.h" |
| 3 | +#include "core/conversion/tensorcontainer/TensorContainer.h" |
| 4 | +#include "core/util/prelude.h" |
| 5 | +#include "torch/torch.h" |
| 6 | + |
| 7 | +#include <ATen/ATen.h> |
| 8 | +#include <vector> |
| 9 | + |
| 10 | +namespace torch_tensorrt { |
| 11 | +namespace core { |
| 12 | +namespace conversion { |
| 13 | +namespace converters { |
| 14 | +namespace impl { |
| 15 | +namespace { |
| 16 | +auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( |
| 17 | + {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", |
| 18 | + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { |
| 19 | + auto self = args[0].ITensorOrFreeze(ctx); |
| 20 | + auto k = 1; |
| 21 | + auto dim = args[1].unwrapToInt(); |
| 22 | + auto largest = true; |
| 23 | + auto selfDim = util::toVec(self->getDimensions()); |
| 24 | + if (dim < 0) { |
| 25 | + dim = selfDim.size() + dim; |
| 26 | + } |
| 27 | + uint32_t shiftDim = 1 << dim; |
| 28 | + |
| 29 | + auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN); |
| 30 | + |
| 31 | + auto new_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); |
| 32 | + TORCHTRT_CHECK(new_layer, "Unable to create max layer from node: " << *n); |
| 33 | + |
| 34 | + auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); |
| 35 | + auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1)); |
| 36 | + |
| 37 | + LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); |
| 38 | + LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); |
| 39 | + |
| 40 | + return true; |
| 41 | + }}); |
| 42 | +} // namespace |
| 43 | +} // namespace impl |
| 44 | +} // namespace converters |
| 45 | +} // namespace conversion |
| 46 | +} // namespace core |
| 47 | +} // namespace torch_tensorrt |
0 commit comments