Skip to content

Commit ca0e48c

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Refactor codegen components to prepare for benchmark generation (#5560)
Summary: Pull Request resolved: #5560 ## Context Refactor operator test code generation scripts, such that components can be re-used to generate operator benchmarks. In broad strokes, the refactors implemented by this diff are as follows: * Improve granularity of Python modules * Replace `test` with `correctness_test`, to make it clear that we are generating correctness tests. **Note that I haven't changed the top level target name `compute_graph_op_tests_bin` since I believe it would be too verbose. ghstack-source-id: 244283559 exported-using-ghexport Reviewed By: nathanaelsee Differential Revision: D63286131 fbshipit-source-id: 1177ea381e6381045f1c97491dd7ec006690f574
1 parent 61cb5b0 commit ca0e48c

File tree

8 files changed

+289
-245
lines changed

8 files changed

+289
-245
lines changed

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import namedtuple
99
from typing import Callable
1010

11-
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite
11+
from executorch.backends.vulkan.test.op_tests.utils.test_suite import VkTestSuite
1212

1313

1414
# Prime numbers dim sizes for testing

backends/vulkan/test/op_tests/generate_op_tests.py renamed to backends/vulkan/test/op_tests/generate_op_correctness_tests.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from typing import Dict
1111

1212
from executorch.backends.vulkan.test.op_tests.cases import test_suites
13+
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
14+
ComputeGraphGen,
15+
)
1316

14-
from executorch.backends.vulkan.test.op_tests.utils.codegen import VkCppTestFileGen
15-
from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
16-
TestSuite,
17-
TestSuiteGen,
17+
from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_vk import (
18+
VkCorrectnessTestFileGen,
1819
)
20+
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
1921
from torchgen import local
2022

2123
from torchgen.gen import parse_native_yaml, ParsedYaml
@@ -37,7 +39,7 @@ def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
3739

3840

3941
def process_test_suites(
40-
cpp_generator: VkCppTestFileGen,
42+
cpp_generator: VkCorrectnessTestFileGen,
4143
f_map: Dict[str, NativeFunction],
4244
test_suites: Dict[str, TestSuite],
4345
) -> None:
@@ -53,12 +55,12 @@ def generate_cpp(
5355
native_functions_yaml_path: str, tags_path: str, output_dir: str
5456
) -> None:
5557
output_file = os.path.join(output_dir, "op_tests.cpp")
56-
cpp_generator = VkCppTestFileGen(output_file)
58+
cpp_generator = VkCorrectnessTestFileGen(output_file)
5759

5860
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
5961
f_map = construct_f_map(parsed_yaml)
6062

61-
TestSuiteGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
63+
ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]
6264

6365
process_test_suites(cpp_generator, f_map, test_suites)
6466

@@ -67,16 +69,14 @@ def generate_cpp(
6769

6870

6971
if __name__ == "__main__":
70-
parser = argparse.ArgumentParser(
71-
description="Generate a simple Hello World C++ program."
72-
)
72+
parser = argparse.ArgumentParser()
7373
parser.add_argument(
7474
"--aten-yaml-path",
7575
help="path to native_functions.yaml file.",
7676
)
7777
parser.add_argument(
7878
"--tags-path",
79-
help="Path to tags.yaml. Required by yaml parsing in codegen system.",
79+
help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
8080
)
8181
parser.add_argument("-o", "--output", help="Output directory", required=True)
8282
args = parser.parse_args()

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ def define_common_targets(is_fbcode = False):
88
return
99

1010
runtime.python_library(
11-
name = "generate_op_tests_lib",
11+
name = "generate_op_correctness_tests_lib",
1212
srcs = native.glob(["utils/*.py"]) + [
13-
"generate_op_tests.py",
13+
"generate_op_correctness_tests.py",
1414
"cases.py",
1515
],
1616
base_module = "executorch.backends.vulkan.test.op_tests",
@@ -21,23 +21,23 @@ def define_common_targets(is_fbcode = False):
2121
)
2222

2323
runtime.python_binary(
24-
name = "generate_op_tests",
25-
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_tests",
24+
name = "generate_op_correctness_tests",
25+
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_correctness_tests",
2626
deps = [
27-
":generate_op_tests_lib",
27+
":generate_op_correctness_tests_lib",
2828
],
2929
)
3030

3131
aten_src_path = runtime.external_dep_location("aten-src-path")
3232
genrule_cmd = [
33-
"$(exe :generate_op_tests)",
33+
"$(exe :generate_op_correctness_tests)",
3434
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
3535
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
3636
"-o $OUT",
3737
]
3838

3939
runtime.genrule(
40-
name = "generated_op_tests_cpp",
40+
name = "generated_op_correctness_tests_cpp",
4141
outs = {
4242
"op_tests.cpp": ["op_tests.cpp"],
4343
},
@@ -66,7 +66,7 @@ def define_common_targets(is_fbcode = False):
6666
runtime.cxx_binary(
6767
name = "compute_graph_op_tests_bin",
6868
srcs = [
69-
":generated_op_tests_cpp[op_tests.cpp]",
69+
":generated_op_correctness_tests_cpp[op_tests.cpp]",
7070
],
7171
define_static_target = False,
7272
deps = [
@@ -79,7 +79,7 @@ def define_common_targets(is_fbcode = False):
7979
runtime.cxx_test(
8080
name = "compute_graph_op_tests",
8181
srcs = [
82-
":generated_op_tests_cpp[op_tests.cpp]",
82+
":generated_op_correctness_tests_cpp[op_tests.cpp]",
8383
],
8484
contacts = ["[email protected]"],
8585
fbandroid_additional_loaded_sonames = [
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
####################
8+
## ATen C++ Types ##
9+
####################
10+
11+
AT_INT_ARRAY_REF = "at::IntArrayRef"
12+
AT_SCALAR = "at::Scalar"
13+
AT_TENSOR = "at::Tensor"
14+
AT_TENSOR_LIST = "at::TensorList"
15+
BOOL = "bool"
16+
DOUBLE = "double"
17+
INT = "int64_t"
18+
OPT_AT_DOUBLE_ARRAY_REF = "::std::optional<at::ArrayRef<double>>"
19+
OPT_AT_INT_ARRAY_REF = "at::OptionalIntArrayRef"
20+
OPT_AT_TENSOR = "::std::optional<at::Tensor>"
21+
OPT_BOOL = "::std::optional<bool>"
22+
OPT_INT64 = "::std::optional<int64_t>"
23+
OPT_DEVICE = "::std::optional<at::Device>"
24+
OPT_LAYOUT = "::std::optional<at::Layout>"
25+
OPT_MEMORY_FORMAT = "::std::optional<at::MemoryFormat>"
26+
OPT_SCALAR_TYPE = "::std::optional<at::ScalarType>"
27+
STRING = "c10::string_view"
28+
TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
29+
THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
30+
TENSOR_VECTOR = "::std::vector<at::Tensor>"

backends/vulkan/test/op_tests/utils/codegen.py renamed to backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 9 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66

77
import re
88
from dataclasses import dataclass
9-
from typing import Any, List, Optional, Union
9+
from typing import List, Optional, Union
1010

11-
from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
11+
from executorch.backends.vulkan.test.op_tests.utils.aten_types import (
1212
AT_INT_ARRAY_REF,
1313
AT_SCALAR,
1414
AT_TENSOR,
1515
AT_TENSOR_LIST,
1616
BOOL,
17-
CppTestFileGen,
1817
DOUBLE,
1918
INT,
2019
OPT_AT_DOUBLE_ARRAY_REF,
@@ -28,37 +27,20 @@
2827
OPT_SCALAR_TYPE,
2928
STRING,
3029
TENSOR_VECTOR,
31-
TestSuite,
32-
TestSuiteGen,
3330
THREE_TENSOR_TUPLE,
3431
TWO_TENSOR_TUPLE,
3532
)
33+
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
3634

3735
from torchgen.api import cpp
3836
from torchgen.api.types import CppSignatureGroup
39-
4037
from torchgen.gen import generate_static_dispatch_backend_call, translate_args
41-
4238
from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature
4339
from torchgen.model import NativeFunction, Variant
4440

45-
##################################
46-
## Custom Test Suite Definition ##
47-
##################################
48-
49-
50-
@dataclass
51-
class VkTestSuite(TestSuite):
52-
def __init__(self, input_cases: List[Any]):
53-
super().__init__(input_cases)
54-
self.storage_types: List[str] = ["utils::kTexture3D"]
55-
self.layouts: List[str] = ["utils::kChannelsPacked"]
56-
self.data_gen: str = "make_rand_tensor"
57-
58-
59-
##########################
60-
## Code Generator Class ##
61-
##########################
41+
###################################
42+
## Compute Graph Code Generation ##
43+
###################################
6244

6345

6446
@dataclass
@@ -105,6 +87,8 @@ def vk_out(self):
10587

10688

10789
class ComputeGraphGen:
90+
backend_key = None
91+
10892
def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
10993
self.op_reg_name = op_reg_name
11094
self.f = f
@@ -230,7 +214,7 @@ def gen_decl(self, fn_name: str, ret_type: str = "void") -> str:
230214

231215
def create_aten_fn_call(self) -> str:
232216
func_call = generate_static_dispatch_backend_call(
233-
self.f_sig, self.f, TestSuiteGen.backend_key
217+
self.f_sig, self.f, ComputeGraphGen.backend_key
234218
)[7:].replace("::cpu", "")
235219

236220
return func_call
@@ -611,147 +595,3 @@ def gen_op_check_fn(self) -> str:
611595
op_check_fn += "\n }"
612596

613597
return op_check_fn
614-
615-
616-
##################################
617-
## Test Fixture Code Generation ##
618-
##################################
619-
620-
test_fixture_template = """
621-
class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, utils::StorageType, utils::GPUMemoryLayout>> {{
622-
protected:
623-
ComputeGraph* graph;
624-
at::ScalarType test_dtype = at::kFloat;
625-
float rtol = {rtol};
626-
float atol = {atol};
627-
628-
void SetUp() override {{
629-
GraphConfig config;
630-
utils::StorageType default_storage_type;
631-
utils::GPUMemoryLayout default_memory_layout;
632-
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
633-
config.set_storage_type_override(default_storage_type);
634-
config.set_memory_layout_override(default_memory_layout);
635-
graph = new ComputeGraph(config);
636-
637-
if (test_dtype == at::kHalf) {{
638-
rtol = 1e-2;
639-
atol = 1e-2;
640-
}}
641-
}}
642-
643-
void TearDown() override {{
644-
delete graph;
645-
graph = nullptr;
646-
}}
647-
648-
{check_fn}
649-
}};
650-
"""
651-
652-
653-
class VkTestSuiteGen(TestSuiteGen):
654-
def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite):
655-
super().__init__(f, inputs)
656-
self.op_reg_name = op_reg_name
657-
self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def)
658-
659-
def generate_fixture_cpp(self) -> str:
660-
check_fn = ""
661-
if not self.suite_def.requires_prepack:
662-
check_fn = self.generator.gen_op_check_fn()
663-
664-
prepacked_check_fn = ""
665-
if self.suite_def.supports_prepack():
666-
self.generator.should_prepack = True
667-
prepacked_check_fn = self.generator.gen_op_check_fn()
668-
check_fn += "\n\n "
669-
check_fn += prepacked_check_fn
670-
671-
return test_fixture_template.format(
672-
op_name=self.op_name,
673-
check_fn=check_fn,
674-
rtol=self.suite_def.rtol,
675-
atol=self.suite_def.atol,
676-
)
677-
678-
def gen_parameterization(self) -> str:
679-
dtypes = self.suite_def.dtypes
680-
storage_types = self.suite_def.storage_types
681-
layouts = self.suite_def.layouts
682-
683-
return f"""
684-
INSTANTIATE_TEST_SUITE_P(
685-
Combos_{self.op_name},
686-
GeneratedOpsTest_{self.op_name},
687-
::testing::Combine(
688-
::testing::Values({', '.join(dtypes)}),
689-
::testing::Values({', '.join(storage_types)}),
690-
::testing::Values({', '.join(layouts)})));
691-
"""
692-
693-
694-
##############################
695-
## Test File Code Generation ##
696-
###############################
697-
698-
preamble_str = """
699-
#include <executorch/backends/vulkan/runtime/api/api.h>
700-
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
701-
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
702-
703-
#include <tuple>
704-
705-
using namespace vkcompute;
706-
using TensorOptions = at::TensorOptions;
707-
708-
vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
709-
switch (at_scalartype) {
710-
case c10::kFloat:
711-
return vkapi::kFloat;
712-
case c10::kHalf:
713-
return vkapi::kHalf;
714-
case c10::kInt:
715-
return vkapi::kInt;
716-
case c10::kLong:
717-
return vkapi::kInt;
718-
case c10::kChar:
719-
return vkapi::kChar;
720-
default:
721-
VK_THROW("Unsupported at::ScalarType!");
722-
}
723-
}
724-
725-
#ifdef USE_VULKAN_FP16_INFERENCE
726-
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) {
727-
#else
728-
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) {
729-
#endif
730-
// Skip checking index tensors
731-
if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
732-
return true;
733-
}
734-
bool is_close = at::allclose(t1, t2, rtol, atol);
735-
if (!is_close && t1.numel() < 500) {
736-
std::cout << "reference: " << std::endl;
737-
print(t1, 150);
738-
std::cout << std::endl;
739-
std::cout << "vulkan: " << std::endl;
740-
print(t2, 150);
741-
std::cout << std::endl;
742-
}
743-
return is_close;
744-
}
745-
"""
746-
747-
748-
class VkCppTestFileGen(CppTestFileGen):
749-
def __init__(self, out_path: str):
750-
super().__init__(out_path)
751-
752-
def generate_preamble(self) -> str:
753-
return preamble_str
754-
755-
def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None:
756-
suites_gen = VkTestSuiteGen(op_reg_name, f, all_input_cases)
757-
self.suites_gens.append(suites_gen)

0 commit comments

Comments
 (0)