5
5
#include " tests/util/util.h"
6
6
#include " torch/csrc/jit/ir/irparser.h"
7
7
#include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+ #include " torch/csrc/jit/passes/canonicalize.h"
9
+ #include " torch/csrc/jit/passes/constant_pooling.h"
8
10
9
11
TEST (LoweringPasses, Conv1dCorrectly) {
10
12
const auto source_graph = R"IR(
@@ -119,7 +121,7 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
119
121
}
120
122
121
123
TEST (LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
122
- const auto source_graph = R"IR(
124
+ std::string source_graph = R"IR(
123
125
graph(%0 : Tensor,
124
126
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
125
127
%2 : Float(3)):
@@ -142,21 +144,21 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
142
144
-> (%res)
143
145
return (%12))IR" ;
144
146
145
- const auto target_graph = R"IR(
147
+ std::string target_graph = R"IR(
146
148
graph(%0 : Tensor,
147
149
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
148
150
%2 : Float(3)):
149
- %3 : bool = prim::Constant[value=0]()
150
151
%4 : int = prim::Constant[value=0]()
151
152
%5 : int = prim::Constant[value=1]()
153
+ %true : bool = prim::Constant[value=1]()
154
+ %3 : bool = prim::Constant[value=0]()
155
+ %output_padding : int[] = prim::Constant[value=[0]]()
152
156
%6 : int = prim::Constant[value=1]()
153
157
%stride : int[] = prim::ListConstruct(%6)
154
158
%padding : int[] = prim::ListConstruct(%4)
155
159
%dilation : int[] = prim::ListConstruct(%5)
156
- %output_padding : int[] = prim::Constant[value=[0]]()
157
160
158
161
# Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
159
- %true : bool = prim::Constant[value=1]()
160
162
%invalid_weight : Tensor = aten::transpose(%0, %4, %5)
161
163
%12 : Tensor = prim::If(%true)
162
164
block0():
@@ -172,9 +174,16 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
172
174
auto sg = std::make_shared<torch::jit::Graph>();
173
175
torch::jit::parseIR (source_graph, &*sg);
174
176
torch_tensorrt::core::lowering::passes::Conv1DToConvolution (sg);
177
+ torch::jit::ConstantPooling (sg);
178
+ sg = torch::jit::Canonicalize (sg, false );
175
179
176
180
auto tg = std::make_shared<torch::jit::Graph>();
177
181
torch::jit::parseIR (target_graph, &*tg);
182
+ torch::jit::ConstantPooling (tg);
183
+ tg = torch::jit::Canonicalize (tg, false );
184
+
185
+ // Validate identical graphs after pooling constants and canonicalizing
186
+ ASSERT_TRUE ((tg->toString () == sg->toString ()));
178
187
179
188
auto in = at::randint (1 , 2 , {1 , 3 , 3 }, {at::kCUDA });
180
189
auto w = at::randint (1 , 2 , {4 , 3 , 3 }, {at::kCUDA });
0 commit comments