Skip to content

Commit 2cfba1a

Browse files
authored
patch rms norm recompose pass
Differential Revision: D69090807 Pull Request resolved: #8170
1 parent 821a2fe commit 2cfba1a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def _get_gamma_node(self, output_node):
3434

3535
def call(self, graph_module: torch.fx.GraphModule):
3636
graph = graph_module.graph
37-
partitions = get_source_partitions(graph, [torch.nn.RMSNorm])
37+
partitions = get_source_partitions(
38+
graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default]
39+
)
3840
for _, src_partitions in partitions.items():
3941
for src_partition in src_partitions:
4042
input_len = len(src_partition.input_nodes)

0 commit comments

Comments
 (0)