Skip to content

Commit 5a984cc

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Generate benchmarks automatically (#5561)
Summary: Pull Request resolved: #5561 ## Context Use the automatic test generation infrastructure to generate operator benchmarks. The overall concept is the same as the test generation; we just structure the generated code in the style of the google benchmark library instead of GTEST. ghstack-source-id: 244287193 Reviewed By: derekxu, nathanaelsee Differential Revision: D63286132 fbshipit-source-id: 25c379accf6664dfca8232db81772b638b41c758
1 parent ca0e48c commit 5a984cc

File tree

4 files changed

+595
-24
lines changed

4 files changed

+595
-24
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import os
9+
10+
from typing import Dict
11+
12+
from executorch.backends.vulkan.test.op_tests.cases import test_suites
13+
14+
from executorch.backends.vulkan.test.op_tests.utils.gen_benchmark_vk import (
15+
VkBenchmarkFileGen,
16+
)
17+
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
18+
ComputeGraphGen,
19+
)
20+
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
21+
from torchgen import local
22+
23+
from torchgen.gen import parse_native_yaml, ParsedYaml
24+
from torchgen.model import DispatchKey, NativeFunction
25+
26+
27+
def registry_name(f: NativeFunction) -> str:
28+
name = str(f.namespace) + "." + str(f.func.name)
29+
if len(f.func.name.overload_name) == 0:
30+
name += ".default"
31+
return name
32+
33+
34+
def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
35+
f_map: Dict[str, NativeFunction] = {}
36+
for f in parsed_yaml.native_functions:
37+
f_map[registry_name(f)] = f
38+
return f_map
39+
40+
41+
def process_test_suites(
42+
cpp_generator: VkBenchmarkFileGen,
43+
f_map: Dict[str, NativeFunction],
44+
test_suites: Dict[str, TestSuite],
45+
) -> None:
46+
for registry_name, op_test_suite in test_suites.items():
47+
f = f_map[registry_name]
48+
cpp_generator.add_suite(registry_name, f, op_test_suite)
49+
50+
51+
@local.parametrize(
52+
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
53+
)
54+
def generate_cpp(
55+
native_functions_yaml_path: str, tags_path: str, output_dir: str
56+
) -> None:
57+
output_file = os.path.join(output_dir, "op_benchmarks.cpp")
58+
cpp_generator = VkBenchmarkFileGen(output_file)
59+
60+
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
61+
f_map = construct_f_map(parsed_yaml)
62+
63+
ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
64+
65+
process_test_suites(cpp_generator, f_map, test_suites)
66+
67+
with open(output_file, "w") as file:
68+
file.write(cpp_generator.generate_cpp())
69+
70+
71+
if __name__ == "__main__":
72+
parser = argparse.ArgumentParser()
73+
parser.add_argument(
74+
"--aten-yaml-path",
75+
help="path to native_functions.yaml file.",
76+
)
77+
parser.add_argument(
78+
"--tags-path",
79+
help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
80+
)
81+
82+
parser.add_argument("-o", "--output", help="Output directory", required=True)
83+
args = parser.parse_args()
84+
generate_cpp(args.aten_yaml_path, args.tags_path, args.output)

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ def define_common_targets(is_fbcode = False):
2020
external_deps = ["torchgen"],
2121
)
2222

23+
runtime.python_library(
24+
name = "generate_op_benchmarks_lib",
25+
srcs = native.glob(["utils/*.py"]) + [
26+
"generate_op_benchmarks.py",
27+
"cases.py",
28+
],
29+
base_module = "executorch.backends.vulkan.test.op_tests",
30+
deps = [
31+
"fbsource//third-party/pypi/expecttest:expecttest",
32+
],
33+
external_deps = ["torchgen"],
34+
)
35+
2336
runtime.python_binary(
2437
name = "generate_op_correctness_tests",
2538
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_correctness_tests",
@@ -28,6 +41,14 @@ def define_common_targets(is_fbcode = False):
2841
],
2942
)
3043

44+
runtime.python_binary(
45+
name = "generate_op_benchmarks",
46+
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_benchmarks",
47+
deps = [
48+
":generate_op_benchmarks_lib",
49+
],
50+
)
51+
3152
aten_src_path = runtime.external_dep_location("aten-src-path")
3253
genrule_cmd = [
3354
"$(exe :generate_op_correctness_tests)",
@@ -45,6 +66,22 @@ def define_common_targets(is_fbcode = False):
4566
default_outs = ["."],
4667
)
4768

69+
benchmarks_genrule_cmd = [
70+
"$(exe :generate_op_benchmarks)",
71+
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
72+
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
73+
"-o $OUT",
74+
]
75+
76+
runtime.genrule(
77+
name = "generated_op_benchmarks_cpp",
78+
outs = {
79+
"op_benchmarks.cpp": ["op_benchmarks.cpp"],
80+
},
81+
cmd = " ".join(benchmarks_genrule_cmd),
82+
default_outs = ["."],
83+
)
84+
4885
pt_operator_library(
4986
name = "all_aten_ops",
5087
check_decl = False,
@@ -76,6 +113,22 @@ def define_common_targets(is_fbcode = False):
76113
],
77114
)
78115

116+
runtime.cxx_binary(
117+
name = "compute_graph_op_benchmarks_bin",
118+
srcs = [
119+
":generated_op_benchmarks_cpp[op_benchmarks.cpp]",
120+
],
121+
compiler_flags = [
122+
"-Wno-unused-variable",
123+
],
124+
define_static_target = False,
125+
deps = [
126+
"//third-party/benchmark:benchmark",
127+
"//executorch/backends/vulkan:vulkan_graph_runtime",
128+
":all_aten_ops_lib",
129+
],
130+
)
131+
79132
runtime.cxx_test(
80133
name = "compute_graph_op_tests",
81134
srcs = [

0 commit comments

Comments
 (0)