
Commit cd7de1d

fix: Add schemas to conv replace
- Add support for transposed conv2d and conv3d, as well as for conv3d
- Add testing for all convolution functions; rename the test accordingly
1 parent 20277d4 commit cd7de1d
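
In effect, the two new passes map the dedicated transposed-convolution ops onto the generic aten::_convolution form that the rest of the lowering pipeline already consumes. A sketch of the 2D case at the IR level, based on the pattern added in this commit (the input graph here is illustrative, and value names after the rewrite may differ):

Before the pass:

    graph(%x, %w, %b, %s, %p, %o, %g, %d):
      %4 : Tensor = aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d)
      return (%4)

After ConvTransposed2DToConvolution:

    graph(%x, %w, %b, %s, %p, %o, %g, %d):
      %1 : bool = prim::Constant[value=1]()
      %2 : bool = prim::Constant[value=1]()
      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
      return (%4)

Here %1 marks the convolution as transposed and %o carries the output padding; %2 feeds the remaining boolean flags of aten::_convolution.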

File tree

6 files changed: +505 -156 lines changed


core/lowering/lowering.cpp

Lines changed: 2 additions & 0 deletions
@@ -124,7 +124,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::Conv1DToConvolution(g);
   passes::ConvTransposed1DToConvolution(g);
   passes::Conv2DToConvolution(g);
+  passes::ConvTransposed2DToConvolution(g);
   passes::Conv3DToConvolution(g);
+  passes::ConvTransposed3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
   // torch::jit::UnrollLoops(g);

core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 77 additions & 35 deletions
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
+#include "torch/csrc/jit/ir/irparser.h"
 
 #include "core/util/prelude.h"
 
@@ -7,81 +8,122 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
+void replaceConv(
+    torch::jit::Block* block,
+    const std::string& node_kind,
+    const std::string& unwrapped_conv,
+    const size_t num_input_args) {
+  // Iterate through nodes in block, searching for aten::conv*
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto nested_block : n->blocks()) {
+      replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+    }
+
+    // If node matches desired kind and number of input arguments, replace it
+    if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+      // Establish insert point within block
+      torch::jit::WithInsertPoint guard(*it);
+
+      // Initialize new fused subgraph from IR code provided
+      auto fused_g = std::make_shared<torch::jit::Graph>();
+      torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+      // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+      torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+      new_output->setType(it->output()->type());
+      it->output()->replaceAllUsesWith(new_output);
+      it.destroyCurrent();
+    }
+  }
+}
 
-  std::string convolution_pattern = R"IR(
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv1d_node_kind = "aten::conv1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
       %2 : int[] = prim::Constant[value=[0]]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
       return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-  map_conv1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
 }
 
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv_transpose1d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %o, %g, %d):
-      %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %o, %g, %d):
       %1 : bool = prim::Constant[value=1]()
       %2 : bool = prim::Constant[value=1]()
       %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";
 
-  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
   LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
 }
 
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv2d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-      return (%4))IR";
-  std::string convolution_pattern = R"IR(
+  const std::string conv2d_node_kind = "aten::conv2d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
      %2 : int[] = prim::Constant[value=[0, 0]]()
      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
      return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-  map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-  map_conv2d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
 }
 
-void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-  std::string conv3d_pattern = R"IR(
-    graph(%x, %w, %b, %s, %p, %d, %g):
-      %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
+void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv_transpose2d_node_kind = "aten::conv_transpose2d";
+  const std::string convolution_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %o, %g, %d):
+      %1 : bool = prim::Constant[value=1]()
+      %2 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
       return (%4))IR";
-  std::string convolution_pattern = R"IR(
+
+  // Schema is aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose2d_node_kind, convolution_pattern, 8);
+  LOG_GRAPH("Post map conv_transpose2d -> _convolution: " << *graph);
+}
+
+void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv3d_node_kind = "aten::conv3d";
+  const std::string convolution_pattern = R"IR(
     graph(%x, %w, %b, %s, %p, %d, %g):
       %1 : bool = prim::Constant[value=0]()
      %2 : int[] = prim::Constant[value=[0, 0, 0]]()
      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
      return (%4))IR";
 
-  // replace matmul + add pattern to linear
-  torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-  map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-  map_conv3d_to_convolution.runOnGraph(graph);
+  // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+  replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
   LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
 }
 
+void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  const std::string conv_transpose3d_node_kind = "aten::conv_transpose3d";
+  const std::string convolution_pattern = R"IR(
+    graph(%x, %w, %b, %s, %p, %o, %g, %d):
+      %1 : bool = prim::Constant[value=1]()
+      %2 : bool = prim::Constant[value=1]()
+      %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
+      return (%4))IR";
+
+  // Schema is aten::conv_transpose3d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+  replaceConv(graph->block(), conv_transpose3d_node_kind, convolution_pattern, 8);
+  LOG_GRAPH("Post map conv_transpose3d -> _convolution: " << *graph);
+}
+
 } // namespace passes
 } // namespace lowering
 } // namespace core
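
The commit message also mentions new tests covering all of these lowerings, but the new test file itself is not part of the diff shown above. Purely as an illustration of the kind of check such a test could perform (the function name, include set, and enclosing project namespace below are assumptions, not taken from the actual test file):

#include <cassert>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

#include "core/lowering/passes/passes.h"

// Assumed to live inside the project's top-level namespace, so that
// core::lowering::passes resolves as in the headers above.
void check_conv_transpose2d_is_lowered() {
  // A graph holding a single aten::conv_transpose2d node, matching the
  // 8-input schema the new pass looks for.
  const std::string source = R"IR(
    graph(%x, %w, %b, %s, %p, %o, %g, %d):
      %out : Tensor = aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, g.get());

  // Run the pass added in this commit.
  core::lowering::passes::ConvTransposed2DToConvolution(g);

  // Every conv_transpose2d node should now have been replaced by
  // aten::_convolution with transposed = true.
  for (auto n : g->block()->nodes()) {
    assert(std::string(n->kind().toQualString()) != "aten::conv_transpose2d");
  }
}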

core/lowering/passes/passes.h

Lines changed: 2 additions & 0 deletions
@@ -15,7 +15,9 @@ void NotateModuleForFallback(
 void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

tests/core/lowering/BUILD

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ cc_test(
 )
 
 lowering_test(
-    name = "test_conv1d_pass",
+    name = "test_conv_pass",
 )
 
 lowering_test(
@@ -102,7 +102,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
-        ":test_conv1d_pass",
+        ":test_conv_pass",
         ":test_device_casting",
         ":test_exception_elimination_pass",
         ":test_linear_to_addmm",

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 0 additions & 119 deletions
This file was deleted.
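
The conv1d-only test deleted here is superseded by the broader convolution test behind the renamed test_conv_pass target in the BUILD file above. Assuming the repository's usual Bazel workflow, that suite would be run with something along the lines of bazel test //tests/core/lowering:test_conv_pass, though the exact invocation and any required configuration flags depend on the project's setup.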
