Skip to content

Commit 746a9d6

Browse files
committed
fix: Minor bugfix in partitioning test
- Partitioning test incorrectly expected 1 conditional engine, but got 2 since `log_sigmoid` operator is not currently supported
1 parent a32e254 commit 746a9d6

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
#include "torch/csrc/jit/passes/canonicalize.h"
9+
#include "torch/csrc/jit/passes/constant_pooling.h"
810

911
TEST(LoweringPasses, Conv1dCorrectly) {
1012
const auto source_graph = R"IR(
@@ -119,7 +121,7 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
119121
}
120122

121123
TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
122-
const auto source_graph = R"IR(
124+
std::string source_graph = R"IR(
123125
graph(%0 : Tensor,
124126
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
125127
%2 : Float(3)):
@@ -142,21 +144,21 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
142144
-> (%res)
143145
return (%12))IR";
144146

145-
const auto target_graph = R"IR(
147+
std::string target_graph = R"IR(
146148
graph(%0 : Tensor,
147149
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
148150
%2 : Float(3)):
149-
%3 : bool = prim::Constant[value=0]()
150151
%4 : int = prim::Constant[value=0]()
151152
%5 : int = prim::Constant[value=1]()
153+
%true : bool = prim::Constant[value=1]()
154+
%3 : bool = prim::Constant[value=0]()
155+
%output_padding : int[] = prim::Constant[value=[0]]()
152156
%6 : int = prim::Constant[value=1]()
153157
%stride : int[] = prim::ListConstruct(%6)
154158
%padding : int[] = prim::ListConstruct(%4)
155159
%dilation : int[] = prim::ListConstruct(%5)
156-
%output_padding : int[] = prim::Constant[value=[0]]()
157160
158161
# Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
159-
%true : bool = prim::Constant[value=1]()
160162
%invalid_weight : Tensor = aten::transpose(%0, %4, %5)
161163
%12 : Tensor = prim::If(%true)
162164
block0():
@@ -172,9 +174,16 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
172174
auto sg = std::make_shared<torch::jit::Graph>();
173175
torch::jit::parseIR(source_graph, &*sg);
174176
torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
177+
torch::jit::ConstantPooling(sg);
178+
sg = torch::jit::Canonicalize(sg, false);
175179

176180
auto tg = std::make_shared<torch::jit::Graph>();
177181
torch::jit::parseIR(target_graph, &*tg);
182+
torch::jit::ConstantPooling(tg);
183+
tg = torch::jit::Canonicalize(tg, false);
184+
185+
// Validate identical graphs after pooling constants and canonicalizing
186+
ASSERT_TRUE((tg->toString() == sg->toString()));
178187

179188
auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
180189
auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});

tests/core/partitioning/test_conditionals.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
3535
auto g = mod.get_method("forward").graph();
3636
torch_tensorrt::core::CompileSpec cfg(inputs);
3737
cfg.partitioning_info.enabled = true;
38+
cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid");
3839
torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
3940
auto new_g = new_mod.get_method("forward").graph();
4041

4142
auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);
4243

43-
ASSERT_TRUE(conditional_engines_count == 1);
44+
ASSERT_TRUE(conditional_engines_count == 2);
4445
}
4546

4647
TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {

0 commit comments

Comments
 (0)