Skip to content

Commit 4793bb2

Browse files
feat: Add support for aten::meshgrid (#1601)
1 parent e612746 commit 4793bb2

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

core/conversion/converters/impl/expand.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,89 @@ auto expand_registrations TORCHTRT_UNUSED =
393393
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
394394
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
395395

396+
return true;
397+
}})
398+
.pattern(
    {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       // Converts aten::meshgrid. torch.meshgrid only supports 1D or 0D input
       // tensors; each output is the corresponding input broadcast across a
       // common N-D grid, where N is the number of inputs.
       auto arg_tensors = args[0].IValue()->toListRef();
       std::vector<nvinfer1::ITensor*> tensors;
       tensors.reserve(arg_tensors.size());
       // const ref: copying the IValue per iteration is needless refcount churn
       for (const auto& t : arg_tensors) {
         if (t.isTensor()) {
           // Freeze constant torch tensors into the TRT network
           auto torch_tensor = t.toTensor();
           tensors.push_back(tensor_to_const(ctx, torch_tensor));
         } else {
           auto cont = t.toCustomClass<TensorContainer>();
           tensors.push_back(cont->tensor());
         }
       }

       // Every output shares one shape: (len(t0), len(t1), ...); a 0D input
       // contributes a dimension of size 1. Guard the fixed-size Dims array.
       TORCHTRT_CHECK(
           static_cast<int64_t>(tensors.size()) <= nvinfer1::Dims::MAX_DIMS,
           "aten::meshgrid with " << tensors.size() << " inputs exceeds the maximum TensorRT tensor rank");
       nvinfer1::Dims output_dims;
       output_dims.nbDims = tensors.size();
       for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
         auto dims = tensors[idx]->getDimensions();
         output_dims.d[idx] = dims.nbDims == 0 ? 1 : dims.d[0];
       }

       std::vector<nvinfer1::ITensor*> out_tensors;
       out_tensors.reserve(tensors.size());
       // Reshape tensors into output shape (reshape, expand): tensor i becomes
       // (1, ..., len(ti), ..., 1) and is then broadcast to output_dims.
       for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
         auto t = tensors[idx];
         auto dims = t->getDimensions();
         nvinfer1::Dims reshape_dims;
         reshape_dims.nbDims = tensors.size();
         for (size_t reshape_idx = 0UL; reshape_idx < tensors.size(); ++reshape_idx) {
           if (reshape_idx == idx) {
             reshape_dims.d[reshape_idx] = dims.nbDims == 0 ? 1 : dims.d[0];
           } else {
             reshape_dims.d[reshape_idx] = 1;
           }
         }
         // Add a reshape layer before expanding dims
         auto reshape_layer = ctx->net->addShuffle(*t);
         TORCHTRT_CHECK(reshape_layer, "Unable to create shuffle layer from node: " << *n);
         reshape_layer->setReshapeDimensions(reshape_dims);
         std::stringstream reshape_layer_name;
         reshape_layer_name << util::node_info(n) << "_meshgrid_reshape_" << std::to_string(idx);
         reshape_layer->setName(reshape_layer_name.str().c_str());
         auto reshaped = reshape_layer->getOutput(0);
         LOG_DEBUG("Tensor " << idx << " reshaped to : " << reshaped->getDimensions() << " from " << dims);

         // Add slice layer for expansion: stride 0 repeats a size-1 dim to the
         // requested output size, stride 1 walks the real dimension.
         std::vector<int64_t> start_vec(output_dims.nbDims, 0);
         auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

         std::vector<int64_t> strides_vec(output_dims.nbDims, 0);
         for (int64_t i = 0; i < output_dims.nbDims; i++) {
           strides_vec[i] = (reshaped->getDimensions().d[i] != 1);
         }
         auto strides = util::toDims(c10::IntArrayRef(strides_vec));

         auto slice_layer = ctx->net->addSlice(*reshaped, start_offset, output_dims, strides);
         TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
         std::stringstream slice_layer_name;
         slice_layer_name << util::node_info(n) << "_meshgrid_slice_" << std::to_string(idx);
         slice_layer->setName(slice_layer_name.str().c_str());
         auto slice_output = slice_layer->getOutput(0);
         LOG_DEBUG("Tensor " << idx << " expanded to : " << slice_output->getDimensions());
         out_tensors.push_back(slice_output);
       }

       // Pack output tensors into a list matching the node's declared list type
       c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
       c10::TypePtr elementType = lt->getElementType();
       auto list = c10::impl::GenericList(elementType);
       list.reserve(out_tensors.size());

       for (auto t : out_tensors) {
         auto tensor_holder = TensorContainer();
         tensor_holder.hold_tensor(t);
         // make_intrusive already yields a prvalue; the previous
         // std::move(temporary) was redundant (-Wpessimizing-move)
         auto ival = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
         list.emplace_back(ival);
       }

       // Move the list itself (not a wrapping temporary) into the IValue
       auto output_list = torch::jit::IValue(std::move(list));
       ctx->AssociateValueAndIValue(n->outputs()[0], output_list);
       return true;
     }});
398481

tests/core/conversion/converters/test_expand.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,38 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
669669

670670
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
671671
}
672+
673+
TEST(Converters, ATenMeshGridConvertsCorrectly) {
  // Three inputs of differing sizes — including a 0-D scalar — exercise the
  // converter's reshape-then-broadcast path.
  const auto graph = R"IR(
      graph(%x : Tensor, %y : Tensor, %z : Tensor):
        %0 : Tensor[] = prim::ListConstruct(%x, %y, %z)
        %1 : Tensor[] = aten::meshgrid(%0)
        %x_0 : Tensor, %y_0 : Tensor, %z_0 : Tensor = prim::ListUnpack(%1)
        return (%x_0, %y_0, %z_0))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto x = at::randint(1, 10, {2}, {at::kCUDA}).to(torch::kInt);
  auto y = at::randint(1, 10, {5}, {at::kCUDA}).to(torch::kInt);
  auto z = torch::tensor(22, {at::kCUDA}).to(torch::kInt); // 0D

  // Reference run through TorchScript
  auto jit_x = at::clone(x);
  auto jit_y = at::clone(y);
  auto jit_z = at::clone(z);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_x, jit_y, jit_z});

  // Run the same graph through the TensorRT engine on cloned inputs
  auto trt_x = at::clone(jit_x);
  auto trt_y = at::clone(jit_y);
  auto trt_z = at::clone(jit_z);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_x, trt_y, trt_z});

  // All three meshgrid outputs must match the TorchScript reference
  for (size_t i = 0; i < jit_results.size(); ++i) {
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt_results[i], 2e-6));
  }
}

0 commit comments

Comments (0)