#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+ #include "torch/csrc/jit/ir/irparser.h"

#include "core/util/prelude.h"

@@ -7,81 +8,122 @@ namespace core {
namespace lowering {
namespace passes {

- void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-   std::string conv1d_pattern = R"IR(
-       graph(%x, %w, %b, %s, %p, %d, %g):
-         %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
-         return (%4))IR";
+ void replaceConv(
+     torch::jit::Block* block,
+     const std::string& node_kind,
+     const std::string& unwrapped_conv,
+     const size_t num_input_args) {
+   // Iterate through nodes in block, searching for aten::conv*
+   for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+     auto n = *it;
+
+     // Recursively explore nested blocks, such as those arising from prim::If
+     for (auto nested_block : n->blocks()) {
+       replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
+     }
+
+     // If node matches desired kind and number of input arguments, replace it
+     if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
+       // Establish insert point within block
+       torch::jit::WithInsertPoint guard(*it);
+
+       // Initialize new fused subgraph from IR code provided
+       auto fused_g = std::make_shared<torch::jit::Graph>();
+       torch::jit::parseIR(unwrapped_conv, fused_g.get());
+
+       // Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
+       torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+       new_output->setType(it->output()->type());
+       it->output()->replaceAllUsesWith(new_output);
+       it.destroyCurrent();
+     }
+   }
+ }

-   std::string convolution_pattern = R"IR(
+ void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+   const std::string conv1d_node_kind = "aten::conv1d";
+   const std::string convolution_pattern = R"IR(
      graph(%x, %w, %b, %s, %p, %d, %g):
        %1 : bool = prim::Constant[value=0]()
        %2 : int[] = prim::Constant[value=[0]]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
        return (%4))IR";

-   torch::jit::SubgraphRewriter map_conv1d_to_convolution;
-   map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
-   map_conv1d_to_convolution.runOnGraph(graph);
+   // Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+   replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
  LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
}

void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-   std::string conv_transpose1d_pattern = R"IR(
-       graph(%x, %w, %b, %s, %p, %o, %g, %d):
-         %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
-         return (%4))IR";
-   std::string convolution_pattern = R"IR(
+   const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
+   const std::string convolution_pattern = R"IR(
      graph(%x, %w, %b, %s, %p, %o, %g, %d):
        %1 : bool = prim::Constant[value=1]()
        %2 : bool = prim::Constant[value=1]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
        return (%4))IR";

-   torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
-   map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
-   map_conv_transpose1d_to_convolution.runOnGraph(graph);
+   // Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+   replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
  LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
}

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-   std::string conv2d_pattern = R"IR(
-       graph(%x, %w, %b, %s, %p, %d, %g):
-         %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
-         return (%4))IR";
-   std::string convolution_pattern = R"IR(
+   const std::string conv2d_node_kind = "aten::conv2d";
+   const std::string convolution_pattern = R"IR(
      graph(%x, %w, %b, %s, %p, %d, %g):
        %1 : bool = prim::Constant[value=0]()
        %2 : int[] = prim::Constant[value=[0, 0]]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
        return (%4))IR";

-   // replace matmul + add pattern to linear
-   torch::jit::SubgraphRewriter map_conv2d_to_convolution;
-   map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
-   map_conv2d_to_convolution.runOnGraph(graph);
+   // Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+   replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
  LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
}

- void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
-   std::string conv3d_pattern = R"IR(
-       graph(%x, %w, %b, %s, %p, %d, %g):
-         %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
+ void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+   const std::string conv_transpose2d_node_kind = "aten::conv_transpose2d";
+   const std::string convolution_pattern = R"IR(
+       graph(%x, %w, %b, %s, %p, %o, %g, %d):
+         %1 : bool = prim::Constant[value=1]()
+         %2 : bool = prim::Constant[value=1]()
+         %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
        return (%4))IR";
-   std::string convolution_pattern = R"IR(
+
+   // Schema is aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+   replaceConv(graph->block(), conv_transpose2d_node_kind, convolution_pattern, 8);
+   LOG_GRAPH("Post map conv_transpose2d -> _convolution: " << *graph);
+ }
+
+ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+   const std::string conv3d_node_kind = "aten::conv3d";
+   const std::string convolution_pattern = R"IR(
      graph(%x, %w, %b, %s, %p, %d, %g):
        %1 : bool = prim::Constant[value=0]()
        %2 : int[] = prim::Constant[value=[0, 0, 0]]()
        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
        return (%4))IR";

-   // replace matmul + add pattern to linear
-   torch::jit::SubgraphRewriter map_conv3d_to_convolution;
-   map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
-   map_conv3d_to_convolution.runOnGraph(graph);
+   // Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
+   replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
  LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
}

+ void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+   const std::string conv_transpose3d_node_kind = "aten::conv_transpose3d";
+   const std::string convolution_pattern = R"IR(
+       graph(%x, %w, %b, %s, %p, %o, %g, %d):
+         %1 : bool = prim::Constant[value=1]()
+         %2 : bool = prim::Constant[value=1]()
+         %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
+         return (%4))IR";
+
+   // Schema is aten::conv_transpose3d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
+   replaceConv(graph->block(), conv_transpose3d_node_kind, convolution_pattern, 8);
+   LOG_GRAPH("Post map conv_transpose3d -> _convolution: " << *graph);
+ }
+
} // namespace passes
} // namespace lowering
} // namespace core
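
To exercise one of these passes on its own, the sketch below parses a small TorchScript graph containing a bare aten::conv1d call (the same shape of graph the old rewrite pattern matched) and runs Conv1DToConvolution over it. It assumes a LibTorch build plus the header from this repo that declares the pass, and it assumes the functions live under a torch_tensorrt::core::lowering::passes namespace; the hunk above only shows the inner namespaces, so treat that qualification as an assumption.

// Minimal standalone driver (hypothetical; not part of this change).
#include <iostream>
#include <memory>
#include <string>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"
// ... plus the header from this repo that declares Conv1DToConvolution.

int main() {
  // A graph holding a single aten::conv1d node, taken from the old rewrite pattern.
  const std::string source = R"IR(
      graph(%x, %w, %b, %s, %p, %d, %g):
        %out : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, g.get());
  std::cout << "Before:\n" << *g;

  // Assumed namespace qualification for the pass defined in this file.
  torch_tensorrt::core::lowering::passes::Conv1DToConvolution(g);

  // The aten::conv1d node should now be an aten::_convolution call with
  // transposed=false and output_padding=[0].
  std::cout << "After:\n" << *g;
  return 0;
}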