Skip to content

Commit 34c30a3

Browse files
authored
Implement a coversion pass: pow(2,x) to mul(x,x).
Differential Revision: D73405855 Pull Request resolved: #10373
1 parent 8191c35 commit 34c30a3

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,6 +2259,34 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
22592259
return result
22602260

22612261

2262+
2263+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
2264+
class ReplacePowWithMullPass(ExportPass):
2265+
"""
2266+
Replace the pow op with degree 2 for a mul op.
2267+
"""
2268+
2269+
def call_operator(
2270+
self,
2271+
op,
2272+
args: Tuple[Argument, ...],
2273+
kwargs: Dict[str, Argument],
2274+
meta: NodeMetadata,
2275+
) -> ProxyValue:
2276+
# TODO(eigen): Add support for other degrees.
2277+
if op not in {
2278+
exir_ops.edge.aten.pow.Scalar,
2279+
} or args[0] != 2:
2280+
return super().call_operator(op, args, kwargs, meta)
2281+
2282+
return super().call_operator(
2283+
exir_ops.edge.aten.mul.Tensor,
2284+
(args[1], args[1]),
2285+
{},
2286+
meta,
2287+
)
2288+
2289+
22622290
# This class encapsulates all the functions that replace/switch one op in the
22632291
# graph with another.
22642292
class CadenceReplaceOpsInGraph:
@@ -2299,4 +2327,5 @@ class CadenceReplaceOpsInGraph:
22992327
ReplaceWhereWithFullArgsWithWhereScalar,
23002328
ReplaceGeluWithApproximateGeluPass,
23012329
ReplaceSplitWithSlicePass,
2330+
ReplacePowWithMullPass,
23022331
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ReplaceEmptyTensorsWithFullPass,
3131
ReplaceFunctionallyEquivalentOpTargets,
3232
ReplaceGeluWithApproximateGeluPass,
33+
ReplacePowWithMullPass,
3334
ReplaceIm2RowWithViewPass,
3435
ReplaceLinearWithFullyConnectedOpPass,
3536
ReplaceMMWithAddMMPass,
@@ -1334,6 +1335,35 @@ def test_replace_split_with_sizes_with_slice(self):
13341335
2,
13351336
)
13361337

1338+
def test_replace_pow_with_mul(self):
1339+
class Pow(torch.nn.Module):
1340+
def forward(self, input):
1341+
return torch.ops.aten.pow.Scalar(2, input)
1342+
1343+
input = torch.randn(2, 1, 64)
1344+
1345+
graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
1346+
1347+
p = ReplacePowWithMullPass()
1348+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1349+
1350+
1351+
self.assertEqual(
1352+
count_node(
1353+
graph_after_passes,
1354+
exir_ops.edge.aten.pow.Scalar,
1355+
),
1356+
0,
1357+
)
1358+
1359+
self.assertEqual(
1360+
count_node(
1361+
graph_after_passes,
1362+
exir_ops.edge.aten.mul.Tensor,
1363+
),
1364+
1,
1365+
)
1366+
13371367

13381368
class TestReplaceIm2rowWithViewPass(unittest.TestCase):
13391369
def test_no_replacement_for_conv(self):

0 commit comments

Comments
 (0)