Skip to content

Commit 098c58e

Browse files
authored
Copy unit tests from torchgen to ET codegen
Differential Revision: D75236020 Pull Request resolved: #11074
1 parent 8620702 commit 098c58e

9 files changed

+1280
-0
lines changed

codegen/test/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain xplat-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()

codegen/test/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
10+
runtime.python_test(
11+
name = "test_gen",
12+
srcs = glob(["test_*.py"]),
13+
package_style = "inplace",
14+
deps = [
15+
"//executorch/codegen:gen_lib",
16+
"fbsource//third-party/pypi/expecttest:expecttest",
17+
],
18+
external_deps = [
19+
"torchgen",
20+
],
21+
)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
import tempfile
10+
import unittest
11+
from typing import Any
12+
from unittest.mock import ANY, Mock, patch
13+
14+
import expecttest
15+
16+
import torchgen
17+
from executorch.codegen.api.custom_ops import ComputeNativeFunctionStub
18+
from executorch.codegen.model import ETKernelIndex
19+
from torchgen.gen_executorch import gen_headers
20+
from torchgen.model import Location, NativeFunction
21+
from torchgen.selective_build.selector import SelectiveBuilder
22+
from torchgen.utils import FileManager
23+
24+
25+
SPACES = " "
26+
27+
28+
def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
29+
native_function, _ = NativeFunction.from_yaml(
30+
yaml_obj,
31+
loc=Location(__file__, 1),
32+
valid_tags=set(),
33+
)
34+
return native_function
35+
36+
37+
class TestComputeNativeFunctionStub(expecttest.TestCase):
38+
"""
39+
Could use torch.testing._internal.common_utils to reduce boilerplate.
40+
GH CI job doesn't build torch before running tools unit tests, hence
41+
manually adding these parametrized tests.
42+
"""
43+
44+
def _test_function_schema_generates_correct_kernel(
45+
self, obj: dict[str, Any], expected: str
46+
) -> None:
47+
func = _get_native_function_from_yaml(obj)
48+
49+
gen = ComputeNativeFunctionStub()
50+
res = gen(func)
51+
self.assertIsNotNone(res)
52+
self.assertExpectedInline(
53+
str(res),
54+
expected,
55+
)
56+
57+
def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
58+
obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"}
59+
expected = """
60+
at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
61+
return out;
62+
}
63+
"""
64+
self._test_function_schema_generates_correct_kernel(obj, expected)
65+
66+
def test_function_schema_generates_correct_kernel_no_out(self) -> None:
67+
obj = {"func": "custom::foo.Tensor(Tensor self) -> Tensor"}
68+
expected = """
69+
at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
70+
return self;
71+
}
72+
"""
73+
self._test_function_schema_generates_correct_kernel(obj, expected)
74+
75+
def test_function_schema_generates_correct_kernel_no_return(self) -> None:
76+
obj = {"func": "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"}
77+
expected = f"""
78+
void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{
79+
{SPACES}
80+
}}
81+
"""
82+
self._test_function_schema_generates_correct_kernel(obj, expected)
83+
84+
def test_function_schema_generates_correct_kernel_3_returns(self) -> None:
85+
obj = {
86+
"func": "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)"
87+
}
88+
expected = """
89+
::std::tuple<at::Tensor,at::Tensor,at::Tensor> wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) {
90+
return ::std::tuple<at::Tensor, at::Tensor, at::Tensor>(
91+
at::Tensor(), at::Tensor(), at::Tensor()
92+
);
93+
}
94+
"""
95+
self._test_function_schema_generates_correct_kernel(obj, expected)
96+
97+
def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None:
98+
obj = {"func": "custom::foo(Tensor[] a) -> Tensor"}
99+
expected = """
100+
at::Tensor wrapper_CPU__foo(at::TensorList a) {
101+
return at::Tensor();
102+
}
103+
"""
104+
self._test_function_schema_generates_correct_kernel(obj, expected)
105+
106+
def test_schema_has_no_return_type_argument_throws(self) -> None:
107+
func = _get_native_function_from_yaml(
108+
{"func": "custom::foo.bool(Tensor self) -> bool"}
109+
)
110+
111+
gen = ComputeNativeFunctionStub()
112+
with self.assertRaisesRegex(Exception, "Can't handle this return type"):
113+
gen(func)
114+
115+
116+
class TestGenCustomOpsHeader(unittest.TestCase):
117+
@patch.object(torchgen.utils.FileManager, "write_with_template")
118+
@patch.object(torchgen.utils.FileManager, "write")
119+
def test_fm_writes_custom_ops_header_when_boolean_is_true(
120+
self, unused: Mock, mock_method: Mock
121+
) -> None:
122+
with tempfile.TemporaryDirectory() as tempdir:
123+
fm = FileManager(tempdir, tempdir, False)
124+
gen_headers(
125+
native_functions=[],
126+
gen_custom_ops_header=True,
127+
custom_ops_native_functions=[],
128+
selector=SelectiveBuilder.get_nop_selector(),
129+
kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type]
130+
cpu_fm=fm,
131+
use_aten_lib=False,
132+
)
133+
mock_method.assert_called_once_with(
134+
"CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
135+
)
136+
137+
@patch.object(torchgen.utils.FileManager, "write_with_template")
138+
@patch.object(torchgen.utils.FileManager, "write")
139+
def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
140+
self, unused: Mock, mock_method: Mock
141+
) -> None:
142+
with tempfile.TemporaryDirectory() as tempdir:
143+
fm = FileManager(tempdir, tempdir, False)
144+
gen_headers(
145+
native_functions=[],
146+
gen_custom_ops_header=False,
147+
custom_ops_native_functions=[],
148+
selector=SelectiveBuilder.get_nop_selector(),
149+
kernel_index=ETKernelIndex(index={}), # type: ignore[arg-type]
150+
cpu_fm=fm,
151+
use_aten_lib=False,
152+
)
153+
mock_method.assert_not_called()

0 commit comments

Comments
 (0)