Skip to content

Commit 99b4a3d

Browse files
author
Github Executorch
committed
Update on "Integrate torchgen exception boundary with ExecuTorch"
As of #7746, we build with exceptions by default, so we just need to use them. Differential Revision: [D67904052](https://our.internmc.facebook.com/intern/diff/D67904052/) [ghstack-poisoned]
2 parents 59b5088 + 2bb7777 commit 99b4a3d

File tree

4 files changed

+88
-26
lines changed

4 files changed

+88
-26
lines changed

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
"TOSA-0.80+MI+8k",
2121
"TOSA-0.80+BI+u55",
2222
]
23-
test_valid_1_00_strings = [
24-
"TOSA-1.00.0+INT+FP+fft",
25-
"TOSA-1.00.0+FP+bf16+fft",
26-
"TOSA-1.00.0+INT+int4+cf",
27-
"TOSA-1.00.0+FP+cf+bf16+8k",
28-
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
29-
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
23+
test_valid_1_0_strings = [
24+
"TOSA-1.0.0+INT+FP+fft",
25+
"TOSA-1.0.0+FP+bf16+fft",
26+
"TOSA-1.0.0+INT+int4+cf",
27+
"TOSA-1.0.0+FP+cf+bf16+8k",
28+
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
29+
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
30+
"TOSA-1.0+INT+FP+fft",
31+
"TOSA-1.0+FP+bf16+fft",
32+
"TOSA-1.0+INT+int4+cf",
33+
"TOSA-1.0+FP+cf+bf16+8k",
34+
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
35+
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
3036
]
3137

32-
test_valid_1_00_extensions = {
38+
test_valid_1_0_extensions = {
3339
"INT": ["int16", "int4", "var", "cf"],
3440
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
3541
}
@@ -40,19 +46,19 @@
4046
"TOSA-0.80+8k",
4147
"TOSA-0.80+BI+MI",
4248
"TOSA-0.80+BI+U55",
43-
"TOSA-1.00.0+fft",
44-
"TOSA-1.00.0+fp+bf16+fft",
45-
"TOSA-1.00.0+INT+INT4+cf",
46-
"TOSA-1.00.0+BI",
47-
"TOSA-1.00.0+FP+FP+INT",
48-
"TOSA-1.00.0+FP+CF+bf16",
49-
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
49+
"TOSA-1.0.0+fft",
50+
"TOSA-1.0.0+fp+bf16+fft",
51+
"TOSA-1.0.0+INT+INT4+cf",
52+
"TOSA-1.0.0+BI",
53+
"TOSA-1.0.0+FP+FP+INT",
54+
"TOSA-1.0.0+FP+CF+bf16",
55+
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
5056
]
5157

5258
test_compile_specs = [
5359
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
5460
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
55-
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
61+
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
5662
]
5763

5864
test_compile_specs_no_version = [
@@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
7076
assert isinstance(tosa_spec, Tosa_0_80)
7177
assert tosa_spec.profile in ["BI", "MI"]
7278

73-
@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
74-
def test_version_string_1_00(self, version_string: str):
79+
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
80+
def test_version_string_1_0(self, version_string: str):
7581
tosa_spec = TosaSpecification.create_from_string(version_string)
7682
assert isinstance(tosa_spec, Tosa_1_00)
7783
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
@@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):
8086

8187
for profile in tosa_spec.profiles:
8288
assert [
83-
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
89+
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
8490
]
8591

8692
@parameterized.expand(test_invalid_strings) # type: ignore[misc]
@@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
103109
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
104110

105111
assert tosa_spec is None
112+
113+
@parameterized.expand(test_valid_0_80_strings)
114+
def test_correct_string_representation_0_80(self, version_string: str):
115+
tosa_spec = TosaSpecification.create_from_string(version_string)
116+
assert isinstance(tosa_spec, Tosa_0_80)
117+
assert f"{tosa_spec}" == version_string
118+
119+
@parameterized.expand(test_valid_1_0_strings)
120+
def test_correct_string_representation_1_0(self, version_string: str):
121+
tosa_spec = TosaSpecification.create_from_string(version_string)
122+
assert isinstance(tosa_spec, Tosa_1_00)
123+
assert f"{tosa_spec}" == version_string

backends/arm/tosa_specification.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -14,7 +14,9 @@
1414
import re
1515
from typing import List
1616

17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
18+
CompileSpec,
19+
)
1820
from packaging.version import Version
1921

2022

@@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
131133
def __repr__(self):
132134
extensions = ""
133135
if self.level_8k:
134-
extensions += "+8K"
136+
extensions += "+8k"
135137
if self.is_U55_subset:
136138
extensions += "+u55"
137139
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
@@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
207209
return "".join(["+" + e for e in self.extensions])
208210

209211
def __repr__(self):
210-
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
212+
extensions = self._get_extensions_string()
213+
if self.level_8k:
214+
extensions += "+8k"
215+
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
211216

212217
def __hash__(self) -> int:
213218
return hash(str(self.version) + self._get_profiles_string())

codegen/tools/gen_all_oplist.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
4747
return real_path
4848

4949

50+
def _raise_if_check_prim_ops_fail(options):
51+
52+
# Error out if we have more than one targets registering prim ops.
53+
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
54+
assert (
55+
options.DEBUG_ONLY_check_prim_ops[0] == "@"
56+
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
57+
58+
prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
59+
with open(prim_ops_targets_file, "r") as file:
60+
prim_ops_targets = file.read().split()
61+
if len(prim_ops_targets) > 1:
62+
# Yellow bold: \033[33;1m
63+
# Red bold: \033[31;1m
64+
# Green bold: \033[32;1m
65+
error = (
66+
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
67+
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
68+
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
69+
+ "\nTo find out the dependency chain, run the following command: "
70+
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
71+
)
72+
raise Exception(error)
73+
74+
5075
def main(argv: List[Any]) -> None:
5176
"""This binary generates 3 files:
5277
@@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
95120
default=False,
96121
required=False,
97122
)
123+
parser.add_argument(
124+
"--DEBUG-ONLY-check-prim-ops",
125+
"--DEBUG_ONLY_check_prim_ops",
126+
help=(
127+
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
128+
),
129+
required=False,
130+
)
98131
options = parser.parse_args(argv)
99132

133+
_raise_if_check_prim_ops_fail(options)
134+
100135
# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
101136
# 1. a yaml file containing selected ops (could be empty), or
102137
# 2. a non-empty list of yaml files in the `model_file_list_path` or
@@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
153188
debug_info_2 = ",".join(
154189
model_dict["operators"][op_name]["debug_info"]
155190
)
156-
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
191+
# Yellow bold: \033[33;1m
192+
# Red bold: \033[31;1m
193+
# Green bold: \033[32;1m
194+
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
157195
if "//" not in debug_info_1 and "//" not in debug_info_2:
158196
error += "\nWe can't determine what BUCK targets these model files belong to."
159197
tail = "."
160198
else:
161199
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
162-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
163-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
200+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
201+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
164202
tail = "as well as results from BUCK commands listed above."
165203

166204
error += (

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def executorch_ops_check(
706706
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
707707
"--allow_include_all_overloads " +
708708
"--check_ops_not_overlapping " +
709+
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
709710
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
710711
define_static_target = False,
711712
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),

0 commit comments

Comments
 (0)