
Commit e657c1c

Merge pull request #1707 from mfeliz-cruise/michael.feliz/aten_any
[feat] Add converter for aten::any.dim

2 parents: e9da9b0 + 84964a5

File tree

2 files changed: +78, -1 lines

core/conversion/converters/impl/reduce.cpp

Lines changed: 34 additions & 1 deletion

@@ -203,7 +203,8 @@ auto reduce_registrations TORCHTRT_UNUSED =
               return true;
             }})
         .pattern(
-            {"aten::min(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+            {"aten::min(Tensor self) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in_tensor = args[0].ITensorOrFreeze(ctx);
               auto in_dims = util::toVec(in_tensor->getDimensions());
 
@@ -216,6 +217,38 @@ auto reduce_registrations TORCHTRT_UNUSED =
               min_layer->setName(util::node_info(n).c_str());
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], min_layer->getOutput(0));
 
+              LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+              return true;
+            }})
+        .pattern(
+            {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto in_dims = in_tensor->getDimensions();
+               auto dim = args[1].unwrapToInt();
+               LOG_DEBUG("Dim to reduce (original): " << dim);
+               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+               LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+               uint32_t axis_mask = 1 << dim;
+               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+
+               auto keepdim = args[2].unwrapToBool();
+               LOG_DEBUG("Keep dims: " << keepdim);
+
+               // Reduce does not work on bool inputs
+               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+                 in_tensor =
+                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+               }
+               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+               sum_layer->setName(util::node_info(n).c_str());
+               auto out_tensor = castITensor(
+                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
               return true;
             }});
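
The new converter lowers aten::any.dim onto primitives TensorRT already provides: it wraps a negative dim to a positive index, encodes that index as the one-hot axis bitmask that addReduce expects, and works around reduce layers rejecting kBOOL inputs by casting to kINT32, summing along the axis, and casting the sums back to kBOOL (a nonzero sum means "any" is true). Below is a minimal CPU-side sketch of that strategy; to_axis_mask and any_last_axis are illustrative names, not Torch-TensorRT APIs.

#include <bitset>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the converter's dim handling: wrap negative dims, then build the
// one-hot axis bitmask that TensorRT's addReduce takes.
uint32_t to_axis_mask(int dim, int nbDims) {
  dim = dim < 0 ? nbDims + dim : dim;  // e.g. dim=-1, nbDims=3 -> 2
  return 1u << dim;
}

// CPU stand-in for the kSUM trick: cast bool -> int32, sum along the last
// axis, and treat any nonzero sum as true (the final cast back to kBOOL).
std::vector<bool> any_last_axis(const std::vector<std::vector<bool>>& in) {
  std::vector<bool> out;
  for (const auto& row : in) {
    int32_t sum = 0;
    for (bool v : row) sum += v ? 1 : 0;
    out.push_back(sum != 0);
  }
  return out;
}

int main() {
  std::cout << std::bitset<32>(to_axis_mask(-1, 3)) << "\n";  // bit 2 set
  for (bool v : any_last_axis({{false, false}, {false, true}}))
    std::cout << v << " ";  // prints: 0 1
}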

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 44 additions & 0 deletions

@@ -300,6 +300,50 @@ TEST(Converters, ATenMeanDimNegIndexKeepDimsConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAnyDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimAllFalseConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=2]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::zeros({3, 7, 4}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {4, 4, 4}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::any(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
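
The four tests cover the paths the converter special-cases: integer and bool inputs, half precision, keepdim on and off, and a negative dim. Assuming test_body follows the pattern used elsewhere in this file (run the graph through TorchScript and through the converter, then compare outputs), the identity being verified reduces to any(x, dim) == (int32 sum along dim != 0). A standalone libtorch check of that identity, on CPU tensors so the sketch runs without a GPU (the tests above use at::kCUDA):

#include <iostream>
#include <torch/torch.h>

int main() {
  // Explicit kBool cast, since the C++ frontend's randint default dtype
  // differs from Python's.
  auto in = torch::randint(0, 2, {4, 4, 4}).to(torch::kBool);
  auto ref = torch::any(in, /*dim=*/1, /*keepdim=*/false);
  // The converter's kSUM trick: cast to int32, sum, nonzero means true.
  auto emulated = in.to(torch::kInt32).sum(1) != 0;
  std::cout << torch::equal(ref, emulated) << "\n";  // expect 1
}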
