|
6 | 6 |
|
7 | 7 | import re
|
8 | 8 | from dataclasses import dataclass
|
9 |
| -from typing import Any, List, Optional, Union |
| 9 | +from typing import List, Optional, Union |
10 | 10 |
|
11 |
| -from executorch.backends.vulkan.test.op_tests.utils.codegen_base import ( |
| 11 | +from executorch.backends.vulkan.test.op_tests.utils.aten_types import ( |
12 | 12 | AT_INT_ARRAY_REF,
|
13 | 13 | AT_SCALAR,
|
14 | 14 | AT_TENSOR,
|
15 | 15 | AT_TENSOR_LIST,
|
16 | 16 | BOOL,
|
17 |
| - CppTestFileGen, |
18 | 17 | DOUBLE,
|
19 | 18 | INT,
|
20 | 19 | OPT_AT_DOUBLE_ARRAY_REF,
|
|
28 | 27 | OPT_SCALAR_TYPE,
|
29 | 28 | STRING,
|
30 | 29 | TENSOR_VECTOR,
|
31 |
| - TestSuite, |
32 |
| - TestSuiteGen, |
33 | 30 | THREE_TENSOR_TUPLE,
|
34 | 31 | TWO_TENSOR_TUPLE,
|
35 | 32 | )
|
| 33 | +from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite |
36 | 34 |
|
37 | 35 | from torchgen.api import cpp
|
38 | 36 | from torchgen.api.types import CppSignatureGroup
|
39 |
| - |
40 | 37 | from torchgen.gen import generate_static_dispatch_backend_call, translate_args
|
41 |
| - |
42 | 38 | from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature
|
43 | 39 | from torchgen.model import NativeFunction, Variant
|
44 | 40 |
|
45 |
| -################################## |
46 |
| -## Custom Test Suite Definition ## |
47 |
| -################################## |
48 |
| - |
49 |
| - |
50 |
| -@dataclass |
51 |
| -class VkTestSuite(TestSuite): |
52 |
| - def __init__(self, input_cases: List[Any]): |
53 |
| - super().__init__(input_cases) |
54 |
| - self.storage_types: List[str] = ["utils::kTexture3D"] |
55 |
| - self.layouts: List[str] = ["utils::kChannelsPacked"] |
56 |
| - self.data_gen: str = "make_rand_tensor" |
57 |
| - |
58 |
| - |
59 |
| -########################## |
60 |
| -## Code Generator Class ## |
61 |
| -########################## |
| 41 | +################################### |
| 42 | +## Compute Graph Code Generation ## |
| 43 | +################################### |
62 | 44 |
|
63 | 45 |
|
64 | 46 | @dataclass
|
@@ -105,6 +87,8 @@ def vk_out(self):
|
105 | 87 |
|
106 | 88 |
|
107 | 89 | class ComputeGraphGen:
|
| 90 | + backend_key = None |
| 91 | + |
108 | 92 | def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
|
109 | 93 | self.op_reg_name = op_reg_name
|
110 | 94 | self.f = f
|
@@ -230,7 +214,7 @@ def gen_decl(self, fn_name: str, ret_type: str = "void") -> str:
|
230 | 214 |
|
231 | 215 | def create_aten_fn_call(self) -> str:
|
232 | 216 | func_call = generate_static_dispatch_backend_call(
|
233 |
| - self.f_sig, self.f, TestSuiteGen.backend_key |
| 217 | + self.f_sig, self.f, ComputeGraphGen.backend_key |
234 | 218 | )[7:].replace("::cpu", "")
|
235 | 219 |
|
236 | 220 | return func_call
|
@@ -611,147 +595,3 @@ def gen_op_check_fn(self) -> str:
|
611 | 595 | op_check_fn += "\n }"
|
612 | 596 |
|
613 | 597 | return op_check_fn
|
614 |
| - |
615 |
| - |
616 |
| -################################## |
617 |
| -## Test Fixture Code Generation ## |
618 |
| -################################## |
619 |
| - |
620 |
| -test_fixture_template = """ |
621 |
| -class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, utils::StorageType, utils::GPUMemoryLayout>> {{ |
622 |
| - protected: |
623 |
| - ComputeGraph* graph; |
624 |
| - at::ScalarType test_dtype = at::kFloat; |
625 |
| - float rtol = {rtol}; |
626 |
| - float atol = {atol}; |
627 |
| -
|
628 |
| - void SetUp() override {{ |
629 |
| - GraphConfig config; |
630 |
| - utils::StorageType default_storage_type; |
631 |
| - utils::GPUMemoryLayout default_memory_layout; |
632 |
| - std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); |
633 |
| - config.set_storage_type_override(default_storage_type); |
634 |
| - config.set_memory_layout_override(default_memory_layout); |
635 |
| - graph = new ComputeGraph(config); |
636 |
| -
|
637 |
| - if (test_dtype == at::kHalf) {{ |
638 |
| - rtol = 1e-2; |
639 |
| - atol = 1e-2; |
640 |
| - }} |
641 |
| - }} |
642 |
| -
|
643 |
| - void TearDown() override {{ |
644 |
| - delete graph; |
645 |
| - graph = nullptr; |
646 |
| - }} |
647 |
| -
|
648 |
| - {check_fn} |
649 |
| -}}; |
650 |
| -""" |
651 |
| - |
652 |
| - |
653 |
| -class VkTestSuiteGen(TestSuiteGen): |
654 |
| - def __init__(self, op_reg_name: str, f: NativeFunction, inputs: VkTestSuite): |
655 |
| - super().__init__(f, inputs) |
656 |
| - self.op_reg_name = op_reg_name |
657 |
| - self.generator = ComputeGraphGen(self.op_reg_name, self.f, self.suite_def) |
658 |
| - |
659 |
| - def generate_fixture_cpp(self) -> str: |
660 |
| - check_fn = "" |
661 |
| - if not self.suite_def.requires_prepack: |
662 |
| - check_fn = self.generator.gen_op_check_fn() |
663 |
| - |
664 |
| - prepacked_check_fn = "" |
665 |
| - if self.suite_def.supports_prepack(): |
666 |
| - self.generator.should_prepack = True |
667 |
| - prepacked_check_fn = self.generator.gen_op_check_fn() |
668 |
| - check_fn += "\n\n " |
669 |
| - check_fn += prepacked_check_fn |
670 |
| - |
671 |
| - return test_fixture_template.format( |
672 |
| - op_name=self.op_name, |
673 |
| - check_fn=check_fn, |
674 |
| - rtol=self.suite_def.rtol, |
675 |
| - atol=self.suite_def.atol, |
676 |
| - ) |
677 |
| - |
678 |
| - def gen_parameterization(self) -> str: |
679 |
| - dtypes = self.suite_def.dtypes |
680 |
| - storage_types = self.suite_def.storage_types |
681 |
| - layouts = self.suite_def.layouts |
682 |
| - |
683 |
| - return f""" |
684 |
| -INSTANTIATE_TEST_SUITE_P( |
685 |
| - Combos_{self.op_name}, |
686 |
| - GeneratedOpsTest_{self.op_name}, |
687 |
| - ::testing::Combine( |
688 |
| - ::testing::Values({', '.join(dtypes)}), |
689 |
| - ::testing::Values({', '.join(storage_types)}), |
690 |
| - ::testing::Values({', '.join(layouts)}))); |
691 |
| - """ |
692 |
| - |
693 |
| - |
694 |
| -############################## |
695 |
| -## Test File Code Generation ## |
696 |
| -############################### |
697 |
| - |
698 |
| -preamble_str = """ |
699 |
| -#include <executorch/backends/vulkan/runtime/api/api.h> |
700 |
| -#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> |
701 |
| -#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> |
702 |
| -
|
703 |
| -#include <tuple> |
704 |
| -
|
705 |
| -using namespace vkcompute; |
706 |
| -using TensorOptions = at::TensorOptions; |
707 |
| -
|
708 |
| -vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { |
709 |
| - switch (at_scalartype) { |
710 |
| - case c10::kFloat: |
711 |
| - return vkapi::kFloat; |
712 |
| - case c10::kHalf: |
713 |
| - return vkapi::kHalf; |
714 |
| - case c10::kInt: |
715 |
| - return vkapi::kInt; |
716 |
| - case c10::kLong: |
717 |
| - return vkapi::kInt; |
718 |
| - case c10::kChar: |
719 |
| - return vkapi::kChar; |
720 |
| - default: |
721 |
| - VK_THROW("Unsupported at::ScalarType!"); |
722 |
| - } |
723 |
| -} |
724 |
| -
|
725 |
| -#ifdef USE_VULKAN_FP16_INFERENCE |
726 |
| -bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) { |
727 |
| -#else |
728 |
| -bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) { |
729 |
| -#endif |
730 |
| - // Skip checking index tensors |
731 |
| - if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) { |
732 |
| - return true; |
733 |
| - } |
734 |
| - bool is_close = at::allclose(t1, t2, rtol, atol); |
735 |
| - if (!is_close && t1.numel() < 500) { |
736 |
| - std::cout << "reference: " << std::endl; |
737 |
| - print(t1, 150); |
738 |
| - std::cout << std::endl; |
739 |
| - std::cout << "vulkan: " << std::endl; |
740 |
| - print(t2, 150); |
741 |
| - std::cout << std::endl; |
742 |
| - } |
743 |
| - return is_close; |
744 |
| -} |
745 |
| -""" |
746 |
| - |
747 |
| - |
748 |
| -class VkCppTestFileGen(CppTestFileGen): |
749 |
| - def __init__(self, out_path: str): |
750 |
| - super().__init__(out_path) |
751 |
| - |
752 |
| - def generate_preamble(self) -> str: |
753 |
| - return preamble_str |
754 |
| - |
755 |
| - def add_suite(self, op_reg_name: str, f: NativeFunction, all_input_cases) -> None: |
756 |
| - suites_gen = VkTestSuiteGen(op_reg_name, f, all_input_cases) |
757 |
| - self.suites_gens.append(suites_gen) |
0 commit comments