Skip to content

Commit d4b3e5c

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Automatically generate operator tests (#2754)
Summary: Pull Request resolved: #2754 ## Context One of the most time consuming parts of adding new operators is writing tests to verify that the implementation is correct. This changeset introduces a codegen solution for automatically generating tests. The goal is to introduce a simple interface to specify what inputs an operator should be checked with, and have a 1 button solution for generating the code and executing operator tests. ## Usage Overview From the developer's perspective, they only need to interact with `op_tests/cases.py`. The file is very simple: ``` # Prime numbers dim sizes for testing XL = 113 L = 89 M2 = 41 M1 = 37 M = 29 S2 = 11 S1 = 7 S = 5 XS = 3 ... def get_mm_inputs(): return [ ((M1, L), (L, M2)), ((S1, S2), (S2, M)), ] test_cases = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), "aten.div.Tensor": get_binary_elementwise_inputs(), "aten.mul.Tensor": get_binary_elementwise_inputs(), "aten.mm.default": get_mm_inputs(), } ``` It just contains a mapping from the name an operator is registered under in the operator registry to a list of inputs for which tests should be generated. To generate and run tests: ``` buck run //xplat/executorch/backends/vulkan/test/op_tests:compute_graph_op_tests_bin ``` ## Design Overview The code generation is mostly built on top of [torchgen](https://github.com/pytorch/pytorch/tree/main/torchgen), which is PyTorch's codegen system for parsing [native_function.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and generating C++ ATen functions from it. The basic idea is: 1. Using the operator registry name, find the corresponding native function for native_function.yaml 2. Use the function schema from the parsed native function to generate test fixtures that can build a Vulkan compute graph for the operator 3. Individual test cases can be generated by creating ATen tensors and calling the ATen operator to get a reference output, then using the test fixture to get a Vulkan output and compare it to the reference output. 4. GTest [test parameterization](https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc) is used to test each test case under a combination of dtypes, storage types, and memory layout [Example generated cpp](https://www.internalfb.com/phabricator/paste/view/P1202279441) Reviewed By: copyrightly Differential Revision: D55446638 fbshipit-source-id: 93ca8e7cd43cee1e2678c489d6f2227507ef256f
1 parent 1c98d78 commit d4b3e5c

File tree

6 files changed

+903
-0
lines changed

6 files changed

+903
-0
lines changed

backends/vulkan/test/op_tests/TARGETS

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
load(":targets.bzl", "define_common_targets")
2+
3+
oncall("executorch")
4+
5+
define_common_targets(is_fbcode = True)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite
9+
10+
11+
# Prime numbers dim sizes for testing
12+
XL = 113
13+
L = 89
14+
M2 = 41
15+
M1 = 37
16+
M = 29
17+
S2 = 11
18+
S1 = 7
19+
S = 5
20+
XS = 3
21+
22+
23+
def get_binary_elementwise_inputs():
24+
return VkTestSuite(
25+
[
26+
((M1, M2), (M1, M2)),
27+
((M1, M2), (M1, 1), 2.0),
28+
((M1, M2), (1, M2)),
29+
((S, S1, S2), (S, S1, S2)),
30+
((S, S1, S2), (S, S1, 1), 2.0),
31+
((S, S1, S2), (S, 1, S2), 2.0),
32+
]
33+
)
34+
35+
36+
def get_mm_inputs():
37+
test_suite = VkTestSuite(
38+
[
39+
((M1, L), (L, M2)),
40+
((S1, S2), (S2, M)),
41+
],
42+
)
43+
test_suite.prepacked_args = ["mat2"]
44+
return test_suite
45+
46+
47+
def get_pool2d_inputs():
48+
test_suite = VkTestSuite(
49+
[
50+
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
51+
]
52+
)
53+
test_suite.supports["layouts"] = ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
54+
return test_suite
55+
56+
57+
test_suites = {
58+
"aten.add.Tensor": get_binary_elementwise_inputs(),
59+
"aten.sub.Tensor": get_binary_elementwise_inputs(),
60+
"aten.div.Tensor": get_binary_elementwise_inputs(),
61+
"aten.mul.Tensor": get_binary_elementwise_inputs(),
62+
"aten.mm.default": get_mm_inputs(),
63+
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
64+
}
65+
66+
prepacked_args = {"aten.mm.default": {"mat2"}}
67+
68+
support_exceptions = {
69+
"aten.max_pool2d_with_indices.default": {
70+
"layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
71+
},
72+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import os
9+
10+
from typing import Dict
11+
12+
from executorch.backends.vulkan.test.op_tests.cases import test_suites
13+
14+
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkCppTestFileGen
15+
from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
16+
TestSuite,
17+
TestSuiteGen,
18+
)
19+
20+
from torchgen.gen import parse_native_yaml, ParsedYaml
21+
from torchgen.model import DispatchKey, NativeFunction
22+
23+
24+
def registry_name(f: NativeFunction) -> str:
25+
name = str(f.namespace) + "." + str(f.func.name)
26+
if len(f.func.name.overload_name) == 0:
27+
name += ".default"
28+
return name
29+
30+
31+
def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
32+
f_map: Dict[str, NativeFunction] = {}
33+
for f in parsed_yaml.native_functions:
34+
f_map[registry_name(f)] = f
35+
return f_map
36+
37+
38+
def process_test_suites(
39+
cpp_generator: VkCppTestFileGen,
40+
f_map: Dict[str, NativeFunction],
41+
test_suites: Dict[str, TestSuite],
42+
) -> None:
43+
for registry_name, op_test_suite in test_suites.items():
44+
f = f_map[registry_name]
45+
cpp_generator.add_suite(registry_name, f, op_test_suite)
46+
47+
48+
def generate_cpp(
49+
native_functions_yaml_path: str, tags_path: str, output_dir: str
50+
) -> None:
51+
output_file = os.path.join(output_dir, "op_tests.cpp")
52+
cpp_generator = VkCppTestFileGen(output_file)
53+
54+
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
55+
f_map = construct_f_map(parsed_yaml)
56+
57+
TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
58+
59+
process_test_suites(cpp_generator, f_map, test_suites)
60+
61+
with open(output_file, "w") as file:
62+
file.write(cpp_generator.generate_cpp())
63+
64+
65+
if __name__ == "__main__":
66+
parser = argparse.ArgumentParser(
67+
description="Generate a simple Hello World C++ program."
68+
)
69+
parser.add_argument(
70+
"--aten-yaml-path",
71+
help="path to native_functions.yaml file.",
72+
)
73+
parser.add_argument(
74+
"--tags-path",
75+
help="Path to tags.yaml. Required by yaml parsing in codegen system.",
76+
)
77+
parser.add_argument("-o", "--output", help="Output directory", required=True)
78+
args = parser.parse_args()
79+
generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
def define_common_targets(is_fbcode = False):
5+
if is_fbcode:
6+
return
7+
8+
runtime.python_library(
9+
name = "generate_op_tests_lib",
10+
srcs = native.glob(["utils/*.py"]) + [
11+
"generate_op_tests.py",
12+
"cases.py",
13+
],
14+
base_module = "executorch.backends.vulkan.test.op_tests",
15+
deps = [
16+
"//caffe2/torchgen:torchgen",
17+
"fbsource//third-party/pypi/expecttest:expecttest",
18+
],
19+
)
20+
21+
runtime.python_binary(
22+
name = "generate_op_tests",
23+
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests",
24+
deps = [
25+
":generate_op_tests_lib",
26+
],
27+
)
28+
29+
aten_src_path = runtime.external_dep_location("aten-src-path")
30+
genrule_cmd = [
31+
"$(exe :generate_op_tests)",
32+
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
33+
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
34+
"-o $OUT",
35+
]
36+
37+
runtime.genrule(
38+
name = "generated_op_tests_cpp",
39+
outs = {
40+
"op_tests.cpp": ["op_tests.cpp"],
41+
},
42+
cmd = " ".join(genrule_cmd),
43+
default_outs = ["."],
44+
)
45+
46+
runtime.cxx_binary(
47+
name = "compute_graph_op_tests_bin",
48+
srcs = [
49+
":generated_op_tests_cpp[op_tests.cpp]",
50+
],
51+
define_static_target = False,
52+
deps = [
53+
"//third-party/googletest:gtest_main",
54+
"//executorch/backends/vulkan:vulkan_graph_runtime",
55+
runtime.external_dep_location("libtorch"),
56+
],
57+
)
58+
59+
runtime.cxx_test(
60+
name = "compute_graph_op_tests",
61+
srcs = [
62+
":generated_op_tests_cpp[op_tests.cpp]",
63+
],
64+
contacts = ["[email protected]"],
65+
fbandroid_additional_loaded_sonames = [
66+
"torch-code-gen",
67+
"vulkan_graph_runtime",
68+
"vulkan_graph_runtime_shaderlib",
69+
],
70+
platforms = [ANDROID],
71+
use_instrumentation_test = True,
72+
deps = [
73+
"//third-party/googletest:gtest_main",
74+
"//executorch/backends/vulkan:vulkan_graph_runtime",
75+
runtime.external_dep_location("libtorch"),
76+
],
77+
)

0 commit comments

Comments
 (0)