
Commit 2ea9f00

Merge pull request #1693 from gs-olive/convNd_lowering_bugfix
fix: Bugfix in convNd_to_convolution lowering pass
2 parents: 20277d4 + 746a9d6
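
Summary: this commit replaces the SubgraphRewriter-based convNd lowering with a recursive node-replacement routine, replaceConv, so that aten::conv1d, aten::conv2d, aten::conv3d, and aten::conv_transpose1d nodes nested inside prim::If blocks are also lowered to aten::_convolution. As a condensed illustration of the failure mode (a hypothetical sketch, not part of the commit; the include path for the passes header is an assumption), the program below parses a graph whose convolutions sit inside a prim::If and runs the rewritten pass; the old pattern rewriter left such nested convolutions untouched:

    #include <memory>
    #include <string>
    #include "torch/csrc/jit/ir/ir.h"
    #include "torch/csrc/jit/ir/irparser.h"
    #include "core/lowering/passes/passes.h" // assumed: declares Conv1DToConvolution

    int main() {
      // aten::conv1d nested inside both arms of a prim::If, as in the new test
      const std::string ir = R"IR(
        graph(%x : Tensor, %w : Tensor, %b : Tensor,
              %s : int[], %p : int[], %d : int[], %g : int, %cond : bool):
          %out : Tensor = prim::If(%cond)
            block0():
              %res : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
              -> (%res)
            block1():
              %res : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
              -> (%res)
          return (%out))IR";

      auto g = std::make_shared<torch::jit::Graph>();
      torch::jit::parseIR(ir, g.get());

      // The recursive replaceConv reaches both nested blocks; the old
      // SubgraphRewriter pattern only rewrote convs in the top-level block.
      torch_tensorrt::core::lowering::passes::Conv1DToConvolution(g);
      g->dump(); // both branches should now call aten::_convolution
      return 0;
    }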

File tree

3 files changed: +134 −36 lines changed

core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 49 additions & 35 deletions
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"

 #include "core/util/prelude.h"

@@ -7,78 +8,91 @@ namespace core {
 namespace lowering {
 namespace passes {

-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}

-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }

 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %o, %g, %d):
-      %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %o, %g, %d):
       %1 : bool = prim::Constant[value=1]()
       %2 : bool = prim::Constant[value=1]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";

-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }

 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }

 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";

-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
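
The heart of the fix above is the traversal pattern: instead of registering a one-node rewrite pattern with torch::jit::SubgraphRewriter, which (as the new test demonstrates) did not reach convolutions inside nested blocks, replaceConv walks every torch::jit::Block recursively. A minimal standalone sketch of that traversal idiom, with a hypothetical helper name (countNodesOfKind) for illustration:

    #include <string>
    #include "torch/csrc/jit/ir/ir.h"

    // Count nodes of a given kind, descending into nested blocks
    // (prim::If arms, prim::Loop bodies) exactly as replaceConv does.
    size_t countNodesOfKind(torch::jit::Block* block, const std::string& kind) {
      size_t count = 0;
      for (auto n : block->nodes()) {
        for (auto nested_block : n->blocks()) {
          count += countNodesOfKind(nested_block, kind);
        }
        if (n->kind().toQualString() == kind) {
          count++;
        }
      }
      return count;
    }

Checking n->inputs().size() in addition to the node kind, as replaceConv does, further guards against schema overloads that take a different number of arguments.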

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 83 additions & 0 deletions
@@ -5,6 +5,8 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/ir/subgraph_matcher.h"
+#include "torch/csrc/jit/passes/canonicalize.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"

 TEST(LoweringPasses, Conv1dCorrectly) {
   const auto source_graph = R"IR(
@@ -117,3 +119,84 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {

   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
 }
+
+TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+      return (%12))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %true : bool = prim::Constant[value=1]()
+      %3 : bool = prim::Constant[value=0]()
+      %output_padding : int[] = prim::Constant[value=[0]]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+      return (%12))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
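
A note on the validation strategy in this test: insertGraph splices a fresh copy of the _convolution pattern, including its prim::Constant nodes, at each replacement site, so the two graphs would not print identically without normalization. The test therefore pools constants and canonicalizes both graphs before comparing their printed forms. Distilled into a hypothetical helper (graphsTextuallyEqual is not part of the commit):

    #include <memory>
    #include "torch/csrc/jit/ir/ir.h"
    #include "torch/csrc/jit/passes/canonicalize.h"
    #include "torch/csrc/jit/passes/constant_pooling.h"

    // Merge duplicate prim::Constant nodes and renumber values
    // deterministically so structurally-equal graphs print identically.
    bool graphsTextuallyEqual(
        std::shared_ptr<torch::jit::Graph> a,
        std::shared_ptr<torch::jit::Graph> b) {
      torch::jit::ConstantPooling(a);
      torch::jit::ConstantPooling(b);
      a = torch::jit::Canonicalize(a, /*keep_unique_names=*/false);
      b = torch::jit::Canonicalize(b, /*keep_unique_names=*/false);
      return a->toString() == b->toString();
    }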

tests/core/partitioning/test_conditionals.cpp

Lines changed: 2 additions & 1 deletion
@@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
   auto g = mod.get_method("forward").graph();
   torch_tensorrt::core::CompileSpec cfg(inputs);
   cfg.partitioning_info.enabled = true;
+  cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid");
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto new_g = new_mod.get_method("forward").graph();

   auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);

-  ASSERT_TRUE(conditional_engines_count == 1);
+  ASSERT_TRUE(conditional_engines_count == 2);
 }

 TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
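
On why the expected count changes from 1 to 2: forcing aten::log_sigmoid to execute in Torch likely splits the TRT-convertible region inside the conditional around that op, so the partitioner builds two TensorRT engines within the conditional rather than one. A sketch of the relevant configuration (assuming `mod` and `inputs` are prepared as in the test):

    // Keep aten::log_sigmoid in Torch; the fallback op divides the
    // surrounding TRT-convertible region, so two engines are expected
    // inside the conditional after compilation.
    torch_tensorrt::core::CompileSpec cfg(inputs);
    cfg.partitioning_info.enabled = true;
    cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid");
    auto new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);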
