Skip to content

Commit 76800bc

Browse files
authored
Merge pull request #2115 from mfeliz-cruise/michael.feliz/dynamic_select_and_masked_fill
[feat] TS: Add support for dynamic select and masked_fill
2 parents 61e338e + 165981c commit 76800bc

File tree

3 files changed

+63
-7
lines changed

3 files changed

+63
-7
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,7 @@ auto select_registrations TORCHTRT_UNUSED =
165165
}
166166

167167
shuffle_layer->setReshapeDimensions(util::squeezeDims(
168-
out->getDimensions(),
169-
dim,
170-
ctx->input_is_dynamic,
171-
ctx->input_is_dynamic && (num_zero_dimensions > 0)));
168+
out->getDimensions(), dim, false, ctx->input_is_dynamic && (num_zero_dimensions > 0)));
172169
shuffle_layer->setName(util::node_info(n).c_str());
173170
out = shuffle_layer->getOutput(0);
174171
}
@@ -710,9 +707,8 @@ auto select_registrations TORCHTRT_UNUSED =
710707
auto val_t_dtype = util::TRTDataTypeToScalarType(self->getType());
711708

712709
// Initialize contant tensor for fill with the inherited data type
713-
auto val_t = tensor_to_const(
714-
ctx, torch::full(util::toVec(self->getDimensions()), val, {torch::dtype(val_t_dtype)}));
715-
710+
std::vector<int64_t> singleton_dims(self->getDimensions().nbDims, 1);
711+
auto val_t = tensor_to_const(ctx, torch::full(singleton_dims, val, {torch::dtype(val_t_dtype)}));
716712
TORCHTRT_CHECK(
717713
util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false),
718714
"Self and mask tensors are not broadcastable");

tests/core/conversion/converters/test_masked_fill.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,40 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
4343
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
4444
}
4545

46+
TEST(Converters, ATenMaskedFillZerosDynamicConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %44 : Device = prim::Constant[value="cuda"]()
        %8 : bool = prim::Constant[value=0]()
        %7 : None = prim::Constant()
        %f32_dtype: int = prim::Constant[value=11]()
        %1 : int = prim::Constant[value=0]() # bert.py:5:26
        %2 : int = prim::Constant[value=1]() # bert.py:5:32
        %33 : int = prim::Constant[value=2]() # bert.py:6:31
        %3 : int[] = prim::ListConstruct(%1, %1, %2)
        %9 : Tensor = aten::tensor(%3, %f32_dtype, %7, %8) # bert.py:5:11
        %mask.1 : Tensor = aten::to(%9, %44, %7, %8, %8) # bert.py:5:11
        %mask.2 : Tensor = trt::const(%mask.1)
        %34 : Tensor = aten::masked_fill(%x.1, %mask.1, %33) # bert.py:6:11
        return (%34, %mask.2))IR";

  // Parse the IR above into a fresh graph.
  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  auto input = at::zeros({1, 2, 3}, {at::kCUDA});

  // Reference result from running the graph through the TorchScript interpreter.
  auto jit_input = at::clone(input);
  auto params = torch_tensorrt::core::ir::get_static_params(parsed_g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(parsed_g, params, {jit_input});

  // TensorRT result via a dynamic-shape engine; strip no-op nodes
  // (e.g. trt::const) before engine construction.
  auto trt_input = at::clone(input);
  torch_tensorrt::core::lowering::passes::RemoveNOPs(parsed_g);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(parsed_g, params, {trt_input});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}
79+
4680
TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
4781
const auto graph = R"IR(
4882
graph(%x.1 : Tensor, %x.2 : Tensor):

tests/core/conversion/converters/test_select.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,32 @@ TEST(Converters, ATenSelectIntConvertsCorrectly) {
3232
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
3333
}
3434

35+
TEST(Converters, ATenSelectIntDynamicConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %2 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::select(%0, %2, %2)
        return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {5, 7, 9}, {at::kCUDA});

  // Reference result from the TorchScript interpreter.
  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  // TensorRT result via a dynamic-shape engine. The graph is unchanged
  // between the two runs, so reuse `params` rather than recomputing the
  // identical get_static_params() result.
  auto trt_in = at::clone(in);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  // aten::select drops the selected dimension; reshape the TRT output to the
  // interpreter result's sizes before comparing.
  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
60+
3561
TEST(Converters, ATenSelectIntDimIsOneConvertsCorrectly) {
3662
const auto graph = R"IR(
3763
graph(%0 : Tensor):

0 commit comments

Comments
 (0)