Skip to content

Commit e7bd9ac

Browse files
committed
Support aten::sum with bool tensor input
TensorRT sum layers do not support bool tensor inputs. Add support by casting the input to int32. Fixes # (issue) Please delete options that are not relevant and/or add your own. - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - This change requires a documentation update - [ ] My code follows the style guidelines of this project (You can use the linters) - [ ] I have performed a self-review of my own code - [ ] I have commented my code, particularly in hard-to-understand areas and hacks - [ ] I have made corresponding changes to the documentation - [ ] I have added tests to verify my fix or my feature - [ ] New and existing unit tests pass locally with my changes - [ ] I have added the relevant labels to my PR in so that relevant reviewers are notified
1 parent 3decf45 commit e7bd9ac

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

core/conversion/converters/impl/reduce.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ auto reduce_registrations TORCHTRT_UNUSED =
7171
auto in_tensor = args[0].ITensorOrFreeze(ctx);
7272
auto in_dims = util::toVec(in_tensor->getDimensions());
7373
LOG_WARNING("Sum Converter disregards dtype");
74+
75+
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
76+
LOG_DEBUG(
77+
"Found type " << in_tensor->getType() << " in aten::sum, casting to "
78+
<< nvinfer1::DataType::kINT32 << " for compatibility.");
79+
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
80+
}
7481

7582
uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
7683

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ converts_keepdims_correctly(mean, Mean);
137137

138138
#undef converts_keepdims_correctly
139139

140+
TEST(Converters, ATenSumBoolConvertsCorrectly) {
141+
const auto graph = R"IR(
142+
graph(%0 : Tensor):
143+
%4 : None = prim::Constant()
144+
%5 : Tensor = aten::sum(%0, %4)
145+
return (%5))IR";
146+
auto in = at::randint(-1, 2, {4, 4, 4}, at::kCUDA).to(at::kBool);
147+
test_body(graph, in);
148+
}
149+
140150
TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) {
141151
const auto graph = R"IR(
142152
graph(%0 : Tensor):

0 commit comments

Comments
 (0)