Skip to content

Commit 2570eec

Browse files
lucylqfacebook-github-bot
authored andcommitted
Generate selective_mobile_ops.h file for executorch (#1021)
Summary: Pull Request resolved: #1021 Generate `selective_mobile_ops_executorch.h` Test using an expecttest. Check cases where we have multiple dtypes and default case (return true). Reviewed By: JacobSzwejbka Differential Revision: D49236461 fbshipit-source-id: 8e229313cc0d3e7175d99b42b9f036203a7808ab
1 parent 0969ca8 commit 2570eec

File tree

3 files changed

+202
-0
lines changed

3 files changed

+202
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/usr/bin/env fbpython
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import os
9+
10+
import yaml
11+
12+
from torchgen.code_template import CodeTemplate
13+
14+
15+
ops_and_dtypes_template_str = """((string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
16+
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)
17+
18+
selected_kernel_dtypes_h_template_str = """#pragma once
19+
/**
20+
* Generated by executorch/codegen/tools/gen_selected_mobile_ops_header.py
21+
*/
22+
23+
inline constexpr bool should_include_kernel_dtype(
24+
const char *operator_name,
25+
exec_aten::ScalarType scalar_type
26+
) {
27+
return $body;
28+
}
29+
"""
30+
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
31+
32+
# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
33+
dtype_enum_to_type = {
34+
"0": "Byte",
35+
"1": "Char",
36+
"2": "Short",
37+
"3": "Int",
38+
"4": "Long",
39+
"5": "Half",
40+
"6": "Float",
41+
"7": "Double",
42+
"8": "ComplexHalf",
43+
"9": "ComplexFloat",
44+
"10": "ComplexDouble",
45+
"11": "Bool",
46+
"12": "QInt8",
47+
"13": "QUInt8",
48+
"14": "QInt32",
49+
"15": "BFloat16",
50+
"16": "QUInt4x2",
51+
"17": "QUInt2x4",
52+
"18": "Bits1x8",
53+
"19": "Bits2x4",
54+
"20": "Bits4x2",
55+
"21": "Bits8",
56+
"22": "Bits16",
57+
}
58+
59+
60+
def write_selected_mobile_ops(output_file_path: str) -> None:
61+
with open(
62+
os.path.join(output_file_path, "selected_operators.yaml"), "r"
63+
) as selected_operators_file:
64+
# Collect et_kernel_metadata from selected_operators.yaml and extract dtypes
65+
# Example format: v1/6;0,1|6;0,1|6;0,1|6;0,1 # Float, 0, 1
66+
selected_operators_dict = yaml.safe_load(selected_operators_file)
67+
et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata", {})
68+
assert isinstance(et_kernel_metadata, dict)
69+
body = "return true;"
70+
body_parts = []
71+
for operator_name, kernel_metadata_str in et_kernel_metadata.items():
72+
tensor_meta = []
73+
for kernel_metadata in kernel_metadata_str:
74+
if kernel_metadata == "default":
75+
break
76+
else:
77+
x = kernel_metadata.split("/")[1]
78+
tensor_meta.extend(x.split("|"))
79+
conditions = ["true"]
80+
if len(tensor_meta) > 0:
81+
dtype_set = set([x.split(";")[0] for x in tensor_meta])
82+
dtype_list = sorted([dtype_enum_to_type[x] for x in dtype_set])
83+
conditions = [
84+
"scalar_type == exec_aten::ScalarType::" + x for x in dtype_list
85+
]
86+
body_parts.append(
87+
ops_and_dtypes_template.substitute(
88+
operator_name=operator_name.removeprefix("aten::"),
89+
dtype_checks=" || ".join(conditions),
90+
),
91+
)
92+
body = "\n || ".join(body_parts)
93+
header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
94+
selected_mobile_ops_path = os.path.join(
95+
output_file_path, "selected_mobile_ops.h"
96+
)
97+
with open(selected_mobile_ops_path, "wb") as out_file:
98+
out_file.write(header_contents.encode("utf-8"))

codegen/tools/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,30 @@ def define_common_targets(is_fbcode = False):
118118
_is_external_target = True,
119119
)
120120

121+
runtime.python_library(
122+
name = "gen_selected_mobile_ops_lib",
123+
srcs = ["gen_selected_mobile_ops.py"],
124+
base_module = "executorch.codegen.tools",
125+
visibility = ["//executorch/..."],
126+
)
127+
128+
runtime.python_test(
129+
name = "test_gen_selected_mobile_ops",
130+
srcs = [
131+
"test/test_gen_selected_mobile_ops.py",
132+
],
133+
package_style = "inplace",
134+
visibility = [
135+
"PUBLIC",
136+
],
137+
deps = [
138+
":gen_selected_mobile_ops_lib",
139+
":gen_oplist_lib",
140+
"fbsource//third-party/pypi/expecttest:expecttest",
141+
],
142+
_is_external_target = True,
143+
)
144+
121145
# TODO(larryliu0820): This is a hack to only run these two on fbcode. These targets depends on exir which is only available in fbcode.
122146
if not runtime.is_oss and is_fbcode:
123147
runtime.python_binary(
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env fbpython
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import os
9+
import tempfile
10+
11+
import executorch.codegen.tools.gen_selected_mobile_ops as gen_selected_mobile_ops
12+
import expecttest
13+
14+
15+
class TestGenSelectedMobileOpsHeader(expecttest.TestCase):
16+
def setUp(self):
17+
self.temp_dir = tempfile.TemporaryDirectory()
18+
self.selected_ops_yaml = os.path.join(
19+
self.temp_dir.name, "selected_operators.yaml"
20+
)
21+
with open(self.selected_ops_yaml, "w") as f:
22+
f.write(
23+
"""
24+
include_all_non_op_selectives: False
25+
include_all_operators: False
26+
debug_info:
27+
- model1@v100
28+
- model2@v50
29+
operators:
30+
aten::add:
31+
is_root_operator: Yes
32+
is_used_for_training: Yes
33+
include_all_overloads: No
34+
aten::add.int:
35+
is_root_operator: No
36+
is_used_for_training: No
37+
include_all_overloads: Yes
38+
kernel_metadata: {}
39+
et_kernel_metadata:
40+
aten::add.out:
41+
# A list of different kernel keys (tensors with dtype-enum/dim-order) combinations used in model
42+
- v1/6;0,1|6;0,1|6;0,1|6;0,1 # Float, 0, 1
43+
- v1/3;0,1|3;0,1|3;0,1|3;0,1 # Int, 0, 1
44+
aten::mul.out:
45+
- v1/6;0,1|6;0,1|6;0,1|6;0,1 # Float, 0, 1
46+
aten::sub.out:
47+
- default
48+
build_features: []
49+
custom_classes: []
50+
"""
51+
)
52+
53+
def tearDown(self):
54+
self.temp_dir.cleanup()
55+
56+
def test_generates_correct_header(self) -> None:
57+
gen_selected_mobile_ops.write_selected_mobile_ops(self.temp_dir.name)
58+
with open(
59+
os.path.join(self.temp_dir.name, "selected_mobile_ops.h"), "r"
60+
) as result:
61+
self.assertExpectedInline(
62+
result.read(),
63+
"""#pragma once
64+
/**
65+
* Generated by executorch/codegen/tools/gen_selected_mobile_ops_header.py
66+
*/
67+
68+
inline constexpr bool should_include_kernel_dtype(
69+
const char *operator_name,
70+
exec_aten::ScalarType scalar_type
71+
) {
72+
return ((string_view(operator_name).compare("add.out") == 0)
73+
&& (scalar_type == exec_aten::ScalarType::Float || scalar_type == exec_aten::ScalarType::Int))
74+
|| ((string_view(operator_name).compare("mul.out") == 0)
75+
&& (scalar_type == exec_aten::ScalarType::Float))
76+
|| ((string_view(operator_name).compare("sub.out") == 0)
77+
&& (true));
78+
}
79+
""",
80+
)

0 commit comments

Comments
 (0)