Skip to content

migrate utils from jarvis to cadence #6720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ load(
"CXX",
)
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("odai_jarvis")

Expand Down Expand Up @@ -103,3 +104,15 @@ executorch_generated_lib(
"//executorch/kernels/portable:operators",
],
)

python_unittest(
name = "test_pass_filter",
srcs = [
"tests/test_pass_filter.py",
],
typing = True,
deps = [
":pass_utils",
"//executorch/exir:pass_base",
],
)
8 changes: 4 additions & 4 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@ class CadencePassAttribute:


# A dictionary that maps an ExportPass to its attributes.
_ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}


def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
return _ALL_CADENCE_PASSES[p]
return ALL_CADENCE_PASSES[p]


# A decorator that registers a pass.
def register_cadence_pass(
pass_attribute: CadencePassAttribute,
) -> Callable[[ExportPass], ExportPass]:
def wrapper(cls: ExportPass) -> ExportPass:
_ALL_CADENCE_PASSES[cls] = pass_attribute
ALL_CADENCE_PASSES[cls] = pass_attribute
return cls

return wrapper


def get_all_available_cadence_passes() -> Set[ExportPass]:
return set(_ALL_CADENCE_PASSES.keys())
return set(ALL_CADENCE_PASSES.keys())


# Create a new filter to filter out relevant passes from all Jarvis passes.
Expand Down
160 changes: 160 additions & 0 deletions backends/cadence/aot/tests/test_pass_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-unsafe


import unittest

from copy import deepcopy

from executorch.backends.cadence.aot import pass_utils
from executorch.backends.cadence.aot.pass_utils import (
ALL_CADENCE_PASSES,
CadencePassAttribute,
create_cadence_pass_filter,
register_cadence_pass,
)

from executorch.exir.pass_base import ExportPass


class TestBase(unittest.TestCase):
def setUp(self):
# Before running each test, create a copy of _all_passes to later restore it after test.
# This avoids messing up the original _all_passes when running tests.
self._all_passes_original = deepcopy(ALL_CADENCE_PASSES)
# Clear _all_passes to do a clean test. It'll be restored after each test in tearDown().
pass_utils.ALL_CADENCE_PASSES.clear()

def tearDown(self):
# Restore _all_passes to original state before test.
pass_utils.ALL_CADENCE_PASSES = self._all_passes_original

def get_filtered_passes(self, filter_):
return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)}


# Test pass registration
class TestPassRegistration(TestBase):
def test_register_cadence_pass(self):
pass_attr_O0 = CadencePassAttribute(opt_level=0)
pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True)
pass_attr_O1_all_backends = CadencePassAttribute(
opt_level=1,
)

# Register 1st pass with opt_level=0
@register_cadence_pass(pass_attr_O0)
class DummyPass_O0(ExportPass):
pass

# Register 2nd pass with opt_level=1, all backends.
@register_cadence_pass(pass_attr_O1_all_backends)
class DummyPass_O1_All_Backends(ExportPass):
pass

# Register 3rd pass with opt_level=None, debug=True
@register_cadence_pass(pass_attr_debug)
class DummyPass_Debug(ExportPass):
pass

# Check if the three passes are indeed added into _all_passes
expected_all_passes = {
DummyPass_O0: pass_attr_O0,
DummyPass_Debug: pass_attr_debug,
DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
}
self.assertEqual(pass_utils.ALL_CADENCE_PASSES, expected_all_passes)


# Test pass filtering
class TestPassFiltering(TestBase):
def test_filter_none(self):
pass_attr_O0 = CadencePassAttribute(opt_level=0)
pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
pass_attr_O1_all_backends = CadencePassAttribute(
opt_level=1,
)

@register_cadence_pass(pass_attr_O0)
class DummyPass_O0(ExportPass):
pass

@register_cadence_pass(pass_attr_O1_debug)
class DummyPass_O1_Debug(ExportPass):
pass

@register_cadence_pass(pass_attr_O1_all_backends)
class DummyPass_O1_All_Backends(ExportPass):
pass

O1_filter = create_cadence_pass_filter(opt_level=1, debug=True)
O1_filter_passes = self.get_filtered_passes(O1_filter)

# Assert that no passes are filtered out.
expected_passes = {
DummyPass_O0: pass_attr_O0,
DummyPass_O1_Debug: pass_attr_O1_debug,
DummyPass_O1_All_Backends: pass_attr_O1_all_backends,
}
self.assertEqual(O1_filter_passes, expected_passes)

def test_filter_debug(self):
pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True)
pass_attr_O2 = CadencePassAttribute(opt_level=2)

@register_cadence_pass(pass_attr_O1_debug)
class DummyPass_O1_Debug(ExportPass):
pass

@register_cadence_pass(pass_attr_O2)
class DummyPass_O2(ExportPass):
pass

debug_filter = create_cadence_pass_filter(opt_level=2, debug=False)
debug_filter_passes = self.get_filtered_passes(debug_filter)

# Assert that debug passees are filtered out, since the filter explicitly
# chooses debug=False.
self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2})

def test_filter_all(self):
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class DummyPass_O1(ExportPass):
pass

@register_cadence_pass(CadencePassAttribute(opt_level=2))
class DummyPass_O2(ExportPass):
pass

debug_filter = create_cadence_pass_filter(opt_level=0)
debug_filter_passes = self.get_filtered_passes(debug_filter)

# Assert that all the passes are filtered out, since the filter only selects
# passes with opt_level <= 0
self.assertEqual(debug_filter_passes, {})

def test_filter_opt_level_None(self):
pass_attr_O1 = CadencePassAttribute(opt_level=1)
pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True)

@register_cadence_pass(CadencePassAttribute(opt_level=None))
class DummyPass_None(ExportPass):
pass

@register_cadence_pass(pass_attr_O1)
class DummyPass_O1(ExportPass):
pass

@register_cadence_pass(pass_attr_O2_debug)
class DummyPass_O2_Debug(ExportPass):
pass

O2_filter = create_cadence_pass_filter(opt_level=2, debug=True)
filtered_passes = self.get_filtered_passes(O2_filter)
# Passes with opt_level=None should never be retained.
expected_passes = {
DummyPass_O1: pass_attr_O1,
DummyPass_O2_Debug: pass_attr_O2_debug,
}
self.assertEqual(filtered_passes, expected_passes)
Loading