Skip to content

Commit ed58cac

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Automatically generate operator tests
Summary: ## Context One of the most time consuming parts of adding new operators is writing tests to verify that the implementation is correct. This changeset introduces a codegen solution for automatically generating tests. The goal is to introduce a simple interface to specify what inputs an operator should be checked with, and have a 1 button solution for generating the code and executing operator tests. ## Usage Overview From the developer's perspective, they only need to interact with `op_tests/cases.py`. The file is very simple: ``` # Prime numbers dim sizes for testing XL = 113 L = 89 M2 = 41 M1 = 37 M = 29 S2 = 11 S1 = 7 S = 5 XS = 3 ... def get_mm_inputs(): return [ ((M1, L), (L, M2)), ((S1, S2), (S2, M)), ] test_cases = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), "aten.div.Tensor": get_binary_elementwise_inputs(), "aten.mul.Tensor": get_binary_elementwise_inputs(), "aten.mm.default": get_mm_inputs(), } ``` It just contains a mapping from the name an operator is registered under in the operator registry to a list of inputs for which tests should be generated. To generate and run tests: ``` buck run //xplat/executorch/backends/vulkan/test/op_tests:compute_graph_op_tests_bin ``` ## Design Overview The code generation is mostly built on top of [torchgen](https://github.com/pytorch/pytorch/tree/main/torchgen), which is PyTorch's codegen system for parsing [native_function.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) and generating C++ ATen functions from it. The basic idea is: 1. Using the operator registry name, find the corresponding native function for native_function.yaml 2. Use the function schema from the parsed native function to generate test fixtures that can build a Vulkan compute graph for the operator 3. Individual test cases can be generated by creating ATen tensors and calling the ATen operator to get a reference output, then using the test fixture to get a Vulkan output and compare it to the reference output. 4. GTest [test parameterization](https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc) is used to test each test case under a combination of dtypes, storage types, and memory layout [Example generated cpp](https://www.internalfb.com/phabricator/paste/view/1201406551) Differential Revision: D55446638
1 parent 45c2557 commit ed58cac

File tree

6 files changed

+830
-0
lines changed

6 files changed

+830
-0
lines changed

backends/vulkan/test/op_tests/TARGETS

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
load(":targets.bzl", "define_common_targets")
2+
3+
oncall("executorch")
4+
5+
define_common_targets()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Prime numbers dim sizes for testing
8+
XL = 113
9+
L = 89
10+
M2 = 41
11+
M1 = 37
12+
M = 29
13+
S2 = 11
14+
S1 = 7
15+
S = 5
16+
XS = 3
17+
18+
19+
def get_binary_elementwise_inputs():
20+
return [
21+
((M1, M2), (M1, M2)),
22+
((M1, M2), (M1, 1), 2.0),
23+
((M1, M2), (1, M2)),
24+
((S, S1, S2), (S, S1, S2)),
25+
((S, S1, S2), (S, S1, 1), 2.0),
26+
((S, S1, S2), (S, 1, S2), 2.0),
27+
]
28+
29+
30+
def get_mm_inputs():
31+
return [
32+
((M1, L), (L, M2)),
33+
((S1, S2), (S2, M)),
34+
]
35+
36+
37+
def get_pool2d_inputs():
38+
return [
39+
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
40+
]
41+
42+
43+
test_cases = {
44+
"aten.add.Tensor": get_binary_elementwise_inputs(),
45+
"aten.sub.Tensor": get_binary_elementwise_inputs(),
46+
"aten.div.Tensor": get_binary_elementwise_inputs(),
47+
"aten.mul.Tensor": get_binary_elementwise_inputs(),
48+
"aten.mm.default": get_mm_inputs(),
49+
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
50+
}
51+
52+
prepacked_args = {"aten.mm.default": {"mat2"}}
53+
54+
support_exceptions = {
55+
"aten.max_pool2d_with_indices.default": {
56+
"layouts": ["api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED"]
57+
},
58+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import os
9+
10+
from executorch.backends.vulkan.test.op_tests.cases import test_cases
11+
12+
from executorch.backends.vulkan.test.op_tests.utils.codegen import (
13+
TestSuiteGen,
14+
VkCppTestFileGen,
15+
)
16+
17+
from torchgen.gen import parse_native_yaml
18+
from torchgen.model import DispatchKey, NativeFunction
19+
20+
21+
def registry_name(f: NativeFunction) -> str:
22+
name = str(f.namespace) + "." + str(f.func.name)
23+
if len(f.func.name.overload_name) == 0:
24+
name += ".default"
25+
return name
26+
27+
28+
def construct_f_map(parsed_yaml):
29+
f_map = {}
30+
for f in parsed_yaml.native_functions:
31+
f_map[registry_name(f)] = f
32+
return f_map
33+
34+
35+
def process_test_suites(cpp_generator, f_map, test_cases):
36+
for registry_name, all_cases in test_cases.items():
37+
f = f_map[registry_name]
38+
cpp_generator.add_suite(registry_name, f, all_cases)
39+
40+
41+
def generate_cpp(native_functions_yaml_path: str, tags_path: str, output_dir: str):
42+
output_file = os.path.join(output_dir, "op_tests.cpp")
43+
cpp_generator = VkCppTestFileGen(output_file)
44+
45+
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
46+
f_map = construct_f_map(parsed_yaml)
47+
48+
TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
49+
50+
process_test_suites(cpp_generator, f_map, test_cases)
51+
52+
with open(output_file, "w") as file:
53+
file.write(cpp_generator.generate_cpp())
54+
55+
56+
if __name__ == "__main__":
57+
parser = argparse.ArgumentParser(
58+
description="Generate a simple Hello World C++ program."
59+
)
60+
parser.add_argument(
61+
"--aten_yaml_path",
62+
help="path to native_functions.yaml file.",
63+
)
64+
parser.add_argument(
65+
"--tags-path",
66+
help="Path to tags.yaml. Required by yaml parsing in codegen system.",
67+
)
68+
parser.add_argument("-o", "--output", help="Output directory", required=True)
69+
args = parser.parse_args()
70+
generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
runtime.python_library(
5+
name = "generate_op_tests_lib",
6+
srcs = native.glob(["utils/*.py"]) + [
7+
"generate_op_tests.py",
8+
"cases.py",
9+
],
10+
base_module = "executorch.backends.vulkan.test.op_tests",
11+
deps = [
12+
"//caffe2/torchgen:torchgen",
13+
"fbsource//third-party/pypi/expecttest:expecttest",
14+
],
15+
)
16+
17+
runtime.python_binary(
18+
name = "generate_op_tests",
19+
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests",
20+
deps = [
21+
":generate_op_tests_lib",
22+
],
23+
)
24+
25+
aten_src_path = runtime.external_dep_location("aten-src-path")
26+
genrule_cmd = [
27+
"$(exe :generate_op_tests)",
28+
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
29+
"--aten_yaml_path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
30+
"-o $OUT",
31+
]
32+
33+
runtime.genrule(
34+
name = "generated_op_tests_cpp",
35+
outs = {
36+
"op_tests.cpp": ["op_tests.cpp"],
37+
},
38+
cmd = " ".join(genrule_cmd),
39+
default_outs = ["."],
40+
)
41+
42+
runtime.cxx_binary(
43+
name = "compute_graph_op_tests_bin",
44+
srcs = [
45+
":generated_op_tests_cpp[op_tests.cpp]",
46+
],
47+
define_static_target = False,
48+
deps = [
49+
"//third-party/googletest:gtest_main",
50+
"//executorch/backends/vulkan:vulkan_graph_runtime",
51+
"//executorch/backends/vulkan/test:test_shader_lib",
52+
runtime.external_dep_location("libtorch"),
53+
],
54+
)

0 commit comments

Comments
 (0)