Skip to content

Commit 43a11c9

Browse files
committed
Canonicalize aten::multiply to aten::mul
1 parent 2b1cedf commit 43a11c9

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

core/lowering/passes/op_aliasing.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph) {
3535
rewrite_scatter.RegisterRewritePattern(scatter_sub_pattern, scatter_pattern);
3636
rewrite_scatter.runOnGraph(graph);
3737
LOG_GRAPH("Post map scatter_ -> scatter: " << *graph);
38+
39+
std::string multiply_pattern = R"IR(
40+
graph(%self, %other):
41+
%o : Tensor = aten::multiply(%self, %other)
42+
return (%o))IR";
43+
std::string mul_pattern = R"IR(
44+
graph(%self, %other):
45+
%o : Tensor = aten::mul(%self, %other)
46+
return (%o))IR";
47+
48+
torch::jit::SubgraphRewriter rewrite_multiply;
49+
rewrite_multiply.RegisterRewritePattern(multiply_pattern, mul_pattern);
50+
rewrite_multiply.runOnGraph(graph);
51+
LOG_GRAPH("Post map multiply -> mul: " << *graph);
3852
}
3953

4054
} // namespace passes

tests/core/lowering/test_operator_aliasing_pass.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,23 @@ TEST(LoweringPasses, LoweringTrueDivideCorrectly) {
2525

2626
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
2727
}
28+
29+
TEST(LoweringPasses, LoweringMultiplyCorrectly) {
30+
std::string source_graph = R"IR(
31+
graph(%s, %o):
32+
%2 = aten::multiply(%s, %o)
33+
return (%2))IR";
34+
std::string target_graph = R"IR(
35+
graph(%s, %o):
36+
%2 = aten::mul(%s, %o)
37+
return (%2))IR";
38+
39+
auto sg = std::make_shared<torch::jit::Graph>();
40+
torch::jit::parseIR(source_graph, sg.get());
41+
torch_tensorrt::core::lowering::passes::AliasOperators(sg);
42+
43+
auto tg = std::make_shared<torch::jit::Graph>();
44+
torch::jit::parseIR(target_graph, tg.get());
45+
46+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
47+
}

0 commit comments

Comments
 (0)