Commit bce8464

feat(element_wise): Auto cast to higher precision for mismatched types

Signed-off-by: Naren Dasan <[email protected]>

1 parent: 5fadfd4

4 files changed (+47, -4 lines)

core/conversion/converters/converter_util.cpp
Lines changed: 7 additions & 0 deletions

@@ -65,6 +65,13 @@ nvinfer1::ILayer* add_elementwise(
     nvinfer1::ITensor* self,
     nvinfer1::ITensor* other,
     const std::string& name) {
+  if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) {
+    LOG_DEBUG("Type mismatch, casting other to " << self->getType());
+    other = castITensor(ctx, other, self->getType());
+  } else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) {
+    LOG_DEBUG("Type mismatch, casting self to " << other->getType());
+    self = castITensor(ctx, self, other->getType());
+  }
   // ensure self to have larger number of dimension
   bool swapSelfOther = false;
   if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
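The new branch promotes whichever operand is INT32 to the other operand's FLOAT type before the elementwise layer is built, reusing the existing castITensor helper from converter_util. As background only: a cast like this is typically implemented in TensorRT with an identity layer whose output type is overridden. The sketch below is an illustration under that assumption, not Torch-TensorRT's actual castITensor (which takes the ConversionCtx and names the layer); the function name and the raw network pointer are made up for the example.

// Illustrative only: cast an ITensor to a target dtype via an identity layer.
// `network` stands in for the nvinfer1::INetworkDefinition being built; the real
// helper goes through ctx and also labels the layer for debugging.
nvinfer1::ITensor* cast_itensor_sketch(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* tensor,
    nvinfer1::DataType dtype) {
  if (tensor->getType() == dtype) {
    return tensor; // already the requested type, nothing to insert
  }
  auto* id_layer = network->addIdentity(*tensor);
  id_layer->setOutputType(0, dtype); // ask TensorRT to cast the identity output
  return id_layer->getOutput(0);
}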

core/conversion/converters/impl/element_wise.cpp
Lines changed: 2 additions & 0 deletions

@@ -412,6 +412,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
   // Should implement self * other
   auto self = args[0].ITensorOrFreeze(ctx);
   auto other = args[1].ITensorOrFreeze(ctx);
+
   auto mul =
       add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
   TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
@@ -426,6 +427,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
   // TODO: Remove with functionalization
   auto self = args[0].ITensorOrFreeze(ctx);
   auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
+
   auto mul =
       add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
   TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
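The two hunks above only add blank lines; the behavioural change for aten::mul comes from add_elementwise in converter_util.cpp. To make the effect concrete: a graph along the lines of the raw-string literal below (a representative reconstruction, since the actual IR literal used by ATenMulConvertsCorrectly is not part of this diff) can now be converted with one float input and one int input, because the converter casts the int operand to float before the kPROD layer instead of handing TensorRT mismatched input types.

// Representative test-style IR (assumed shape of the mul test graph):
const auto graph = R"IR(
    graph(%0 : Tensor, %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      return (%2))IR";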

tests/core/conversion/converters/test_element_wise.cpp
Lines changed: 35 additions & 2 deletions

@@ -12,7 +12,9 @@ void pointwise_test_helper(
     std::vector<int64_t> shape1 = {5},
     std::vector<int64_t> shape2 = {5},
     bool negative_input = false,
-    bool int_tensors = false) {
+    bool int_tensors = false,
+    bool float_int_tensors = false,
+    bool int_float_tensors = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());

@@ -27,11 +29,24 @@ void pointwise_test_helper(
   if (!singleInput) {
     torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
   }
+
+  TORCHTRT_CHECK(!((int_tensors && (float_int_tensors || int_float_tensors)) || (float_int_tensors && int_float_tensors)),
+      "Invalid test configuration, only one of int_tensors, float_int_tensors, int_float_tensors can be true");
+
   if(int_tensors){
     for(size_t i = 0UL; i < torch_inputs.size(); ++i){
       torch_inputs[i] = torch_inputs[i].to(at::kInt);
     }
+  } else if(float_int_tensors) {
+    TORCHTRT_CHECK(!singleInput, "float_int_tensors tests require two inputs");
+    torch_inputs[0] = torch_inputs[0].to(at::kFloat);
+    torch_inputs[1] = torch_inputs[1].to(at::kInt);
+  } else if (int_float_tensors) {
+    TORCHTRT_CHECK(!singleInput, "int_float_tensors tests require two inputs");
+    torch_inputs[0] = torch_inputs[0].to(at::kInt);
+    torch_inputs[1] = torch_inputs[1].to(at::kFloat);
   }
+
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, torch_inputs);
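The two new flags are exercised by the test updates below; for readability, here are the two call shapes with the trailing booleans annotated (negative_input, int_tensors, float_int_tensors, int_float_tensors, matching the signature above):

// float self (input 0), int other (input 1): hits the "cast other" branch in add_elementwise
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
// int self, float other: hits the "cast self" branch
pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);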

@@ -62,6 +77,8 @@ TEST(Converters, ATenAddConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
@@ -75,9 +92,11 @@ TEST(Converters, ATenAddWithAlphaConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
-TEST(Converters, ATenAddImplicitWithAlphaConvertsCorrectly) {
+TEST(Converters, ATenAddInplaceWithAlphaConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):
         %2 : float = prim::Constant[value=7.6]()
@@ -109,6 +128,8 @@ TEST(Converters, ATenSubConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenMulConvertsCorrectly) {
@@ -121,6 +142,8 @@ TEST(Converters, ATenMulConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
@@ -151,6 +174,8 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -173,6 +198,8 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -186,6 +213,8 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -211,6 +240,8 @@ TEST(Converters, ATenPowTensorConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenPowScalarConvertsCorrectly) {
@@ -251,6 +282,8 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, false, false, true);
 }
 
 TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {

tests/util/run_graph_engine.cpp
Lines changed: 3 additions & 2 deletions

@@ -30,6 +30,7 @@ std::vector<core::ir::Input> toInputsDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
 
   for (auto i : ten) {
     auto opt = core::util::toVec(i.sizes());
+    auto dtype = core::util::ScalarTypeToTRTDataType(i.scalar_type());
 
     if (dynamic_batch) {
       std::vector<int64_t> min_range(opt);
@@ -38,15 +39,15 @@ std::vector<core::ir::Input> toInputsDynamic(std::vector<at::Tensor> ten, bool dynamic_batch) {
       min_range[0] = ceil(opt[0] / 2.0);
       max_range[0] = 2 * opt[0];
 
-      a.push_back(core::ir::Input(min_range, opt, max_range));
+      a.push_back(core::ir::Input(min_range, opt, max_range, dtype));
     } else {
       std::vector<int64_t> min_range(opt);
       std::vector<int64_t> max_range(opt);
 
       min_range[1] = ceil(opt[1] / 2.0);
       max_range[1] = 2 * opt[1];
 
-      a.push_back(core::ir::Input(min_range, opt, max_range));
+      a.push_back(core::ir::Input(min_range, opt, max_range, dtype));
     }
   }
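Passing the tensor's dtype into core::ir::Input is what lets the mixed-type tests actually reach the new cast path: the engine inputs now carry a TensorRT type derived from each test tensor rather than an implicit default (presumably float). Roughly, the mapping this relies on is the following; this is illustrative, as ScalarTypeToTRTDataType's full behaviour is not shown in this diff.

// Assumed mapping for the types used by these tests:
//   at::kFloat -> nvinfer1::DataType::kFLOAT
//   at::kInt   -> nvinfer1::DataType::kINT32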
