Skip to content

Commit f7a97cb

Browse files
committed
fix: Repair Citrinet-1024 compilation issues
- Enable automatic type-casting in `aten::sum` for bool tensor inputs to agree with Torch casting behavior
- Fix bug in `aten::div` where all internal div layers have the same name
1 parent 1361028 commit f7a97cb

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -325,7 +325,8 @@ auto element_wise_registrations TORCHTRT_UNUSED =
325325
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
326326
} else if (rounding_mode == "trunc") {
327327
// trunc = floor(abs(div)) * sign(div)
328-
auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
328+
auto tmp_div = add_elementwise(
329+
ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n) + "_tmp_div");
329330
auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
330331

331332
// In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this

core/conversion/converters/impl/reduce.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,14 @@ auto reduce_registrations TORCHTRT_UNUSED =
113113
LOG_DEBUG("Keep dims: " << keepdim);
114114

115115
LOG_WARNING("Sum converter disregards dtype");
116+
117+
if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
118+
LOG_DEBUG(
119+
"Found type " << in_tensor->getType() << " in aten::sum, casting to "
120+
<< nvinfer1::DataType::kINT32 << " for compatibility.");
121+
in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32);
122+
}
123+
116124
auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
117125

118126
TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);

0 commit comments

Comments (0)