Skip to content

Commit 7456d31

Browse files
committed
fix: Error on aten::div with truncation
- `aten::div` with truncation on integer tensor inputs currently throws an error if both inputs are integer type, as the TRT unary operations for absolute value and floor do not apply to Int32 or Bool types
- For absolute value, this is a legitimate bug, as `aten::abs` is functional for integer types
- For the floor operation, `aten::floor` does not explicitly support integer inputs, and `torch.floor()` does not work with Int32 inputs by default. However, `torch.div(..., rounding_mode="trunc")` with integer tensors does return an integer value, and so the corresponding Torch-TRT converter should behave similarly
- Modified `aten::abs` converter logic to be a utility, as it is used in multiple locations
- Added regression test to ensure truncation divide with two integer tensors is functional
1 parent b5bcccf commit 7456d31

File tree

5 files changed

+83
-36
lines changed

5 files changed

+83
-36
lines changed

core/conversion/converters/converter_util.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,37 @@ nvinfer1::ILayer* add_elementwise(
156156
return ele;
157157
}
158158

159+
// Computes the elementwise absolute value of `self` and returns the layer
// producing the result.
//
// Uses IUnaryLayer::kABS directly when the input type supports it
// (FLOAT/HALF/INT8). For other types (e.g. Int32/Bool tensors, which TRT's
// unary kABS rejects), falls back to an elementwise construction
// abs(x) = max(x, -1 * x).
//
// ctx:  active conversion context owning the network being built
// n:    JIT node being converted (used for error/log identification)
// self: input tensor to take the absolute value of
// name: name assigned to the layer that produces the final output
nvinfer1::ILayer* add_absolute_value(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* self,
    const std::string& name) {
  nvinfer1::ILayer* absolute_value;

  // Check if TRT Unary ops support the input type
  bool unary_supported_input = (self->getType() == nvinfer1::DataType::kFLOAT) ||
      (self->getType() == nvinfer1::DataType::kHALF) || (self->getType() == nvinfer1::DataType::kINT8);
  if (unary_supported_input) {
    absolute_value = ctx->net->addUnary(*self, nvinfer1::UnaryOperation::kABS);
    TORCHTRT_CHECK(absolute_value, "Unable to create abs layer from node: " << *n);
    absolute_value->setName(name.c_str());
  } else {
    // NOTE: closing parenthesis added to balance "(max(x, -x))" in the log text
    LOG_GRAPH(
        "Tensor is of unsupported type "
        << self->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x))");
    // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
    at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(self->getType()));
    auto neg_one_const = tensor_to_const(ctx, neg_one);
    auto neg_layer = add_elementwise(
        ctx, nvinfer1::ElementWiseOperation::kPROD, self, neg_one_const, util::node_info(n) + std::string("_Negation"));
    TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
    absolute_value = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, neg_layer->getOutput(0), name);
    TORCHTRT_CHECK(absolute_value, "Unable to create max layer from node: " << *n);
  }

  return absolute_value;
}
189+
159190
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& tensor_name) {
160191
auto id_layer = ctx->net->addIdentity(*tensor);
161192
auto id_out_tensor = id_layer->getOutput(0);

core/conversion/converters/converter_util.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ nvinfer1::ILayer* add_elementwise(
4242
nvinfer1::ITensor* other,
4343
const std::string& name);
4444

45+
// Returns a layer computing the elementwise absolute value of `self`, named
// `name`. Uses IUnaryLayer::kABS for types it supports (FLOAT/HALF/INT8) and
// falls back to an elementwise max(x, -x) construction for other types
// (e.g. Int32/Bool).
nvinfer1::ILayer* add_absolute_value(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* self,
    const std::string& name);
50+
4551
// Apply an identity operation on a tensor. Used in the case where an input is an output to a network.
4652
nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& name);
4753

core/conversion/converters/impl/element_wise.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,24 @@ auto element_wise_registrations TORCHTRT_UNUSED =
326326
} else if (rounding_mode == "trunc") {
327327
// trunc = floor(abs(div)) * sign(div)
328328
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
329-
auto abs = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kABS);
330-
auto floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
329+
auto abs = add_absolute_value(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
330+
331+
// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
332+
// specific function. Floor applied to non-float types equates to identity
333+
nvinfer1::ILayer* floor;
334+
if ((abs->getOutput(0)->getType() == nvinfer1::DataType::kINT32) ||
335+
(abs->getOutput(0)->getType() == nvinfer1::DataType::kBOOL)) {
336+
LOG_GRAPH(
337+
"Tensor is of unsupported type " << abs->getOutput(0)->getType()
338+
<< " for IUnaryLayer::kFLOOR. Using identity instead.");
339+
floor = ctx->net->addIdentity(*abs->getOutput(0));
340+
TORCHTRT_CHECK(floor, "Unable to create identity layer from node: " << *n);
341+
} else {
342+
floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
343+
TORCHTRT_CHECK(floor, "Unable to create floor layer from node: " << *n);
344+
}
345+
floor->setName((util::node_info(n) + "_floor").c_str());
346+
331347
auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
332348
div = add_elementwise(
333349
ctx,

core/conversion/converters/impl/unary.cpp

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,10 @@ namespace {
1313
// Converter for aten::abs. The type-dependent absolute-value construction
// (unary kABS vs. elementwise max(x, -x) fallback) lives in the shared
// add_absolute_value utility so it can be reused by other converters.
auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      // Resolve (or freeze) the input into a TRT ITensor, then delegate the
      // absolute-value computation to the shared converter utility.
      auto self = args[0].ITensorOrFreeze(ctx);
      auto abs_layer = add_absolute_value(ctx, n, self, util::node_info(n));
      // Bind the layer output to the node's output value so downstream
      // converters can look it up.
      auto result = ctx->AssociateValueAndTensor(n->outputs()[0], abs_layer->getOutput(0));
      LOG_DEBUG("Output tensor shape: " << result->getDimensions());
      return true;
    }});
5121

5222
auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "gtest/gtest.h"
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/torch.h"
78

89
void pointwise_test_helper(
910
std::string graph_ir,
@@ -235,6 +236,29 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
235236
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
236237
}
237238

239+
// Regression test: aten::div with rounding_mode="trunc" on two Int32 tensors
// previously errored out because TRT's unary kABS/kFLOOR do not accept Int32
// inputs; the converter now falls back to elementwise abs and an identity
// "floor" for integer types.
TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor, %1 : Tensor):
        %trunc : str = prim::Constant[value="trunc"]()
        %out : Tensor = aten::div(%0, %1, %trunc)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Avoid divide-by-zero issues by making denominator >= 1
  auto in_0 = at::randint(-5, 5, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);
  auto in_1 = at::randint(1, 10, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});

  // Integer truncation division is exact, so JIT and TRT results must match
  // exactly (no approximate tolerance needed).
  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0])));
}
261+
238262
TEST(Converters, ATenPowTensorConvertsCorrectly) {
239263
const auto graph = R"IR(
240264
graph(%x.1 : Tensor, %x2.1 : Tensor):

0 commit comments

Comments
 (0)