Skip to content

Commit def9762

Browse files
Min Guofacebook-github-bot
authored andcommitted
debug_bmm_to_matmul
1 parent a142ddf commit def9762

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,13 @@ def call(self, graph_module: torch.fx.GraphModule):
5050
)
5151
for _, src_partitions in partitions.items():
5252
for src_partition in src_partitions:
53+
print("partition node")
54+
print(src_partition.nodes)
5355
op_cnt = Counter([n.target for n in src_partition.nodes])
5456
if op_cnt not in self.patterns:
55-
continue
57+
raise AssertionError(
58+
"Found a new pattern needed be converted to linear op"
59+
)
5660

5761
inputs = src_partition.input_nodes
5862
bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]

backends/qualcomm/_passes/convert_to_linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def _convert(self, graph_module: torch.fx.GraphModule):
193193
partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
194194
for _, src_partitions in partitions.items():
195195
for src_partition in src_partitions:
196+
print("partition node to linear")
197+
print(src_partition.nodes)
196198
op_cnt = Counter(
197199
[
198200
n.target

0 commit comments

Comments
 (0)