
Commit a32e254

fix: Bugfix in convNd_to_convolution lowering pass
- Lowering pass did not respect `prim::If` block boundaries: when replacing a subgraph that occurs within an "If" block, the subgraph rewriter places the replacement logic outside of the block, so the rewritten code executes both the "if" and the "else" paths regardless of the condition
- Refactor the convNd implementation to use a more precise guard-insert paradigm instead of subgraph rewriting (see the sketch below)
- Write a general function that applies to all convolution replacements
- Add a test case to validate the refactoring on conv1d
1 parent deda87b commit a32e254
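
Why the paradigm change matters: torch::jit::SubgraphRewriter matches a textual pattern anywhere in the graph, but it splices the replacement nodes relative to the graph as a whole, so a match inside a prim::If block can end up hoisted outside the block — exactly the miscompilation described above. The guard-insert approach pins the insertion point to the matched node itself. Below is a minimal sketch of that idiom; the helper name spliceInPlace is illustrative only, and the commit's actual implementation is the replaceConv function in the diff that follows.

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

// Sketch of the guard-insert idiom (illustrative; the commit's real helper is
// replaceConv below). Splices `replacement_ir` in place of node `n`, keeping
// the new nodes inside n's owning block.
void spliceInPlace(torch::jit::Node* n, const std::string& replacement_ir) {
  // Pin the insertion point to the matched node rather than the graph's
  // top-level block, so prim::If boundaries are respected
  torch::jit::WithInsertPoint guard(n);

  // Parse the replacement subgraph from its textual IR
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(replacement_ir, g.get());

  // Inline the replacement at the guard point, rewiring inputs and the output
  torch::jit::Value* out =
      torch::jit::insertGraph(*n->owningGraph(), *g, n->inputs()).at(0);
  out->setType(n->output()->type());
  n->output()->replaceAllUsesWith(out);
  n->destroy();
}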

File tree

2 files changed: +123 −35 lines


core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 49 additions & 35 deletions
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"
 
 #include "core/util/prelude.h"
 
@@ -7,78 +8,91 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}
 
-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }
 
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %o, %g, %d):
-      %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %o, %g, %d):
       %1 : bool = prim::Constant[value=1]()
       %2 : bool = prim::Constant[value=1]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }
 
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }
 
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0, 0, 0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
      return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
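
Because replaceConv owns the block recursion and node matching, each per-variant pass reduces to a node kind, a replacement IR pattern, and an input arity. As a hypothetical sketch of how a further variant would plug in (aten::convNd and its 1-D-style pattern are placeholders for illustration, not ops touched by this commit):

void ConvNDToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
  // Hypothetical variant; mirrors the Conv1D case above
  const std::string convNd_node_kind = "aten::convNd"; // placeholder op name
  const std::string convolution_pattern = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
      %1 : bool = prim::Constant[value=0]()
      %2 : int[] = prim::Constant[value=[0]]()
      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
      return (%4))IR";

  // Schema would be aten::convNd(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
  replaceConv(graph->block(), convNd_node_kind, convolution_pattern, 7);
  LOG_GRAPH("Post map convNd -> _convolution: " << *graph);
}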

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 74 additions & 0 deletions
@@ -117,3 +117,77 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
 }
+
+TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
+  const auto source_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
+          -> (%res)
+      return (%12))IR";
+
+  const auto target_graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
+          %2 : Float(3)):
+      %3 : bool = prim::Constant[value=0]()
+      %4 : int = prim::Constant[value=0]()
+      %5 : int = prim::Constant[value=1]()
+      %6 : int = prim::Constant[value=1]()
+      %stride : int[] = prim::ListConstruct(%6)
+      %padding : int[] = prim::ListConstruct(%4)
+      %dilation : int[] = prim::ListConstruct(%5)
+      %output_padding : int[] = prim::Constant[value=[0]]()
+
+      # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
+      %true : bool = prim::Constant[value=1]()
+      %invalid_weight : Tensor = aten::transpose(%0, %4, %5)
+      %12 : Tensor = prim::If(%true)
+        block0():
+          %res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+        block1():
+          %res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
+          -> (%res)
+      return (%12))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
+  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
+  auto b = at::randint(1, 10, {4}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto trt_w = at::clone(w);
+  auto trt_b = at::clone(b);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
