Skip to content

Commit b91e3e0

Browse files
authored
Merge pull request #1512 from mfeliz-cruise/michael.feliz/sum_bool
Support aten::sum with bool tensor input
2 parents 3decf45 + 20a0716 commit b91e3e0

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

core/conversion/converters/impl/reduce.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ auto reduce_registrations TORCHTRT_UNUSED =
7272
auto in_dims = util::toVec(in_tensor->getDimensions());
7373
LOG_WARNING("Sum Converter disregards dtype");
7474

75+
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
76+
LOG_DEBUG(
77+
"Found type " << in_tensor->getType() << " in aten::sum, casting to "
78+
<< nvinfer1::DataType::kINT32 << " for compatibility.");
79+
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
80+
}
81+
7582
uint32_t axis_mask = (uint32_t)(((uint64_t)1 << in_dims.size()) - 1);
7683

7784
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, false);

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ converts_keepdims_correctly(mean, Mean);
137137

138138
#undef converts_keepdims_correctly
139139

140+
TEST(Converters, ATenSumBoolConvertsCorrectly) {
141+
const auto graph = R"IR(
142+
graph(%0 : Tensor):
143+
%4 : None = prim::Constant()
144+
%5 : Tensor = aten::sum(%0, %4)
145+
return (%5))IR";
146+
auto in = at::randint(-1, 2, {4, 4, 4}, at::kCUDA).to(at::kBool);
147+
test_body(graph, in);
148+
}
149+
140150
TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) {
141151
const auto graph = R"IR(
142152
graph(%0 : Tensor):

0 commit comments

Comments
 (0)