Skip to content

Commit 5b1af77

Browse files
authored
Unsqueeze operator with dynamic inout (#1624)
1 parent 8c6b0a7 commit 5b1af77

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

core/conversion/converters/impl/unsqueeze.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ auto unsqueeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().
3232

3333
auto shuffle_layer = ctx->net->addShuffle(*self);
3434
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
35-
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(self->getDimensions(), dim));
35+
shuffle_layer->setReshapeDimensions(util::unsqueezeDims(self->getDimensions(), dim, 1, false));
3636

3737
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
3838

tests/core/conversion/converters/test_unsqueeze.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,25 @@ TEST(Converters, ATenUnsqueezeNegativeDimConvertsCorrectly) {
4747
ASSERT_TRUE(
4848
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
4949
}
50+
51+
TEST(Converters, ATenUnsqueezeConvertsCorrectlyWithDynamicInput) {
52+
const auto graph = R"IR(
53+
graph(%0 : Tensor):
54+
%1 : int = prim::Constant[value=1]()
55+
%2 : Tensor = aten::unsqueeze(%0, %1)
56+
return (%2))IR";
57+
58+
auto g = std::make_shared<torch::jit::Graph>();
59+
torch::jit::parseIR(graph, g.get());
60+
61+
auto in = at::randint(1, 10, {1, 10}, {at::kCUDA});
62+
63+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
64+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
65+
66+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
67+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
68+
69+
ASSERT_TRUE(
70+
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
71+
}

0 commit comments

Comments
 (0)