Skip to content

Commit 54e022f

Browse files
committed
fix: Add schemas to conv replace
- Add support for transposed conv2d and conv3d, as well as for conv3d - Add testing for all convolution functions, rename test accordingly
1 parent fb42d42 commit 54e022f

File tree

6 files changed

+456
-204
lines changed

6 files changed

+456
-204
lines changed

core/lowering/lowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
124124
passes::Conv1DToConvolution(g);
125125
passes::ConvTransposed1DToConvolution(g);
126126
passes::Conv2DToConvolution(g);
127+
passes::ConvTransposed2DToConvolution(g);
127128
passes::Conv3DToConvolution(g);
129+
passes::ConvTransposed3DToConvolution(g);
128130
passes::FuseAddMMBranches(g);
129131
passes::RemoveBNDimCheck(g);
130132
// torch::jit::UnrollLoops(g);

core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
8282
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
8383
}
8484

85+
void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
86+
const std::string conv_transpose2d_node_kind = "aten::conv_transpose2d";
87+
const std::string convolution_pattern = R"IR(
88+
graph(%x, %w, %b, %s, %p, %o, %g, %d):
89+
%1 : bool = prim::Constant[value=1]()
90+
%2 : bool = prim::Constant[value=1]()
91+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
92+
return (%4))IR";
93+
94+
// Schema is aten::conv_transpose2d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
95+
replaceConv(graph->block(), conv_transpose2d_node_kind, convolution_pattern, 8);
96+
LOG_GRAPH("Post map conv_transpose2d -> _convolution: " << *graph);
97+
}
98+
8599
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
86100
const std::string conv3d_node_kind = "aten::conv3d";
87101
const std::string convolution_pattern = R"IR(
@@ -96,6 +110,20 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
96110
LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
97111
}
98112

113+
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
114+
const std::string conv_transpose3d_node_kind = "aten::conv_transpose3d";
115+
const std::string convolution_pattern = R"IR(
116+
graph(%x, %w, %b, %s, %p, %o, %g, %d):
117+
%1 : bool = prim::Constant[value=1]()
118+
%2 : bool = prim::Constant[value=1]()
119+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
120+
return (%4))IR";
121+
122+
// Schema is aten::conv_transpose3d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
123+
replaceConv(graph->block(), conv_transpose3d_node_kind, convolution_pattern, 8);
124+
LOG_GRAPH("Post map conv_transpose3d -> _convolution: " << *graph);
125+
}
126+
99127
} // namespace passes
100128
} // namespace lowering
101129
} // namespace core

core/lowering/passes/passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ void NotateModuleForFallback(
1515
void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1616
void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1717
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
18+
void ConvTransposed2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1819
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
20+
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1921
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
2022
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2123
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

tests/core/lowering/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ cc_test(
2828
)
2929

3030
lowering_test(
31-
name = "test_conv1d_pass",
31+
name = "test_conv_pass",
3232
)
3333

3434
lowering_test(
@@ -102,7 +102,7 @@ lowering_test(
102102
test_suite(
103103
name = "lowering_tests",
104104
tests = [
105-
":test_conv1d_pass",
105+
":test_conv_pass",
106106
":test_device_casting",
107107
":test_exception_elimination_pass",
108108
":test_linear_to_addmm",

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 0 additions & 202 deletions
This file was deleted.

0 commit comments

Comments
 (0)