Skip to content

Commit b32f5a3

Browse files
cccclaifacebook-github-bot
authored andcommitted
add util to get non lower nodes (#1088)
Summary: Pull Request resolved: #1088 Add a util function to get the non lowered nodes from the graph so it's easy to do post processing Reviewed By: tarun292 Differential Revision: D50507998 fbshipit-source-id: fa8770218185df7db3f35454b7538864df40b182
1 parent 28e1165 commit b32f5a3

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

exir/backend/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ python_unittest(
266266
"test_utils.py",
267267
],
268268
deps = [
269+
":op_partitioner_demo",
269270
"//caffe2:torch",
270271
"//executorch/exir:lib",
271272
"//executorch/exir/backend:backend_api",

exir/backend/test/test_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from executorch.exir import CaptureConfig
1212
from executorch.exir.backend.backend_api import to_backend
1313
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
14+
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1415
from executorch.exir.backend.utils import (
16+
get_non_lowered_nodes,
1517
is_identical_graph,
1618
remove_first_quant_and_last_dequant,
1719
replace_quantized_partition_with_op,
@@ -356,3 +358,25 @@ def forward(self, input):
356358
FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
357359
actual_static_quant_linear.code
358360
)
361+
362+
def test_get_non_lowered_nodes(self):
363+
class Model(torch.nn.Module):
364+
def __init__(self):
365+
super().__init__()
366+
367+
def forward(self, a, x, b):
368+
y = torch.mm(a, x)
369+
z = y + b
370+
a = z - a
371+
y = torch.mm(a, x)
372+
z = y + b
373+
return z
374+
375+
m = Model()
376+
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
377+
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
378+
edge.exported_program = to_backend(edge.exported_program, AddMulPartitionerDemo)
379+
edge.dump()
380+
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
381+
# Only sub is not not lowerable
382+
self.assertEqual(len(number_of_cpu_nodes), 1)

exir/backend/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8+
import operator
89
from collections import defaultdict
910
from functools import lru_cache
1011
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
1112

1213
import torch
14+
from executorch.exir.delegate import executorch_call_delegate
1315
from executorch.exir.dialects._ops import ops as exir_ops
1416

1517
from executorch.exir.lowered_backend_module import create_submodule_from_nodes
@@ -165,6 +167,32 @@ def replace_quantized_partition_with_op(
165167
return (replaced_op, dequant_nodes, quant_nodes)
166168

167169

170+
def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
171+
"""
172+
Check if the node is the getitem followed by executorch_call_delegate node. These getitems node
173+
are just for getting the result from delegate because the input/output to delegates are flattened
174+
"""
175+
return (
176+
node.target == operator.getitem
177+
and len(node.args) == 2
178+
and node.args[0].target == executorch_call_delegate # pyre-ignore
179+
and isinstance(node.args[1], int)
180+
)
181+
182+
183+
def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
184+
"""
185+
Returns a list of non lowered nodes in the graph module.
186+
"""
187+
return [
188+
node
189+
for node in graph.nodes
190+
if node.op == "call_function"
191+
and node.target != executorch_call_delegate
192+
and (not _get_item_from_executorch_call_delegate(node))
193+
]
194+
195+
168196
# TODO - style: use templated types
169197
class DelegateMappingBuilder:
170198
"""

0 commit comments

Comments
 (0)