Skip to content

add util to get non lower nodes #1088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/backend/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ python_unittest(
"test_utils.py",
],
deps = [
":op_partitioner_demo",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
Expand Down
24 changes: 24 additions & 0 deletions exir/backend/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from executorch.exir import CaptureConfig
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import (
get_non_lowered_nodes,
is_identical_graph,
remove_first_quant_and_last_dequant,
replace_quantized_partition_with_op,
Expand Down Expand Up @@ -356,3 +358,25 @@ def forward(self, input):
FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
actual_static_quant_linear.code
)

def test_get_non_lowered_nodes(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, x, b):
y = torch.mm(a, x)
z = y + b
a = z - a
y = torch.mm(a, x)
z = y + b
return z

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
edge.exported_program = to_backend(edge.exported_program, AddMulPartitionerDemo)
edge.dump()
number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
# Only sub is not not lowerable
self.assertEqual(len(number_of_cpu_nodes), 1)
28 changes: 28 additions & 0 deletions exir/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
# LICENSE file in the root directory of this source tree.

import logging
import operator
from collections import defaultdict
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.lowered_backend_module import create_submodule_from_nodes
Expand Down Expand Up @@ -165,6 +167,32 @@ def replace_quantized_partition_with_op(
return (replaced_op, dequant_nodes, quant_nodes)


def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
"""
Check if the node is the getitem followed by executorch_call_delegate node. These getitems node
are just for getting the result from delegate because the input/output to delegates are flattened
"""
return (
node.target == operator.getitem
and len(node.args) == 2
and node.args[0].target == executorch_call_delegate # pyre-ignore
and isinstance(node.args[1], int)
)


def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
"""
Returns a list of non lowered nodes in the graph module.
"""
return [
node
for node in graph.nodes
if node.op == "call_function"
and node.target != executorch_call_delegate
and (not _get_item_from_executorch_call_delegate(node))
]


# TODO - style: use templated types
class DelegateMappingBuilder:
"""
Expand Down