Skip to content

Commit 73bb1f9

Browse files
authored
fix the expand_copy lower issue
Differential Revision: D69470884 Pull Request resolved: #8380
1 parent 994d94d commit 73bb1f9

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def _get_ordered_inputs(
4646
def call(self, graph_module: torch.fx.GraphModule):
4747
graph = graph_module.graph
4848
partitions = get_source_partitions(
49-
graph, [operator.matmul, torch.matmul, torch.bmm]
49+
graph,
50+
[operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
5051
)
5152
for _, src_partitions in partitions.items():
5253
for src_partition in src_partitions:

backends/qualcomm/_passes/convert_to_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.No
190190
return ret
191191

192192
def _convert(self, graph_module: torch.fx.GraphModule):
193-
partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
193+
partitions = get_source_partitions(
194+
graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default]
195+
)
194196
for _, src_partitions in partitions.items():
195197
for src_partition in src_partitions:
196198
op_cnt = Counter(

0 commit comments

Comments
 (0)