Skip to content

Commit 1341c8e

Browse files
committed
fix-up
1 parent 8a5f52b commit 1341c8e

File tree

5 files changed

+144
-69
lines changed

5 files changed

+144
-69
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#define op(X, A, B) ${OPERATOR}
1616

17-
#include "indexing_utils_u16.h"
17+
#include "indexing_utils.h"
1818

1919
layout(std430) buffer;
2020

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
#define op(X, A, B) ${OPERATOR}
2424

25-
#include "indexing_utils_u16.h"
25+
#include "indexing_utils.h"
2626

2727
layout(std430) buffer;
2828

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#define op(X, A, B) ${OPERATOR}
1818

19-
#include "indexing_utils_u16.h"
19+
#include "indexing_utils.h"
2020

2121
layout(std430) buffer;
2222

backends/vulkan/runtime/graph/ops/glsl/indexing_utils_u16.h

Lines changed: 0 additions & 19 deletions
This file was deleted.

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 141 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,6 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
33
load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
44
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
55

6-
def define_test_targets(test_name, extra_deps = [], src_file = None, is_fbcode = False):
7-
deps_list = [
8-
"//third-party/googletest:gtest_main",
9-
"//executorch/backends/vulkan:vulkan_graph_runtime",
10-
runtime.external_dep_location("libtorch"),
11-
] + extra_deps
12-
13-
src_file_str = src_file if src_file else "{}.cpp".format(test_name)
14-
15-
runtime.cxx_binary(
16-
name = "{}_bin".format(test_name),
17-
srcs = [
18-
src_file_str,
19-
],
20-
compiler_flags = [
21-
"-Wno-unused-variable",
22-
],
23-
define_static_target = False,
24-
deps = deps_list,
25-
)
26-
27-
runtime.cxx_test(
28-
name = test_name,
29-
srcs = [
30-
src_file_str,
31-
],
32-
contacts = ["[email protected]"],
33-
fbandroid_additional_loaded_sonames = [
34-
"torch-code-gen",
35-
"vulkan_graph_runtime",
36-
"vulkan_graph_runtime_shaderlib",
37-
],
38-
platforms = [ANDROID],
39-
use_instrumentation_test = True,
40-
deps = deps_list,
41-
)
42-
43-
446
def define_common_targets(is_fbcode = False):
457
if is_fbcode:
468
return
@@ -120,6 +82,19 @@ def define_common_targets(is_fbcode = False):
12082
default_outs = ["."],
12183
)
12284

85+
runtime.cxx_binary(
86+
name = "compute_graph_op_tests_bin",
87+
srcs = [
88+
":generated_op_correctness_tests_cpp[op_tests.cpp]",
89+
],
90+
define_static_target = False,
91+
deps = [
92+
"//third-party/googletest:gtest_main",
93+
"//executorch/backends/vulkan:vulkan_graph_runtime",
94+
runtime.external_dep_location("libtorch"),
95+
],
96+
)
97+
12398
runtime.cxx_binary(
12499
name = "compute_graph_op_benchmarks_bin",
125100
srcs = [
@@ -136,17 +111,136 @@ def define_common_targets(is_fbcode = False):
136111
],
137112
)
138113

139-
define_test_targets(
140-
"compute_graph_op_tests",
141-
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
114+
runtime.cxx_test(
115+
name = "compute_graph_op_tests",
116+
srcs = [
117+
":generated_op_correctness_tests_cpp[op_tests.cpp]",
118+
],
119+
contacts = ["[email protected]"],
120+
fbandroid_additional_loaded_sonames = [
121+
"torch-code-gen",
122+
"vulkan_graph_runtime",
123+
"vulkan_graph_runtime_shaderlib",
124+
],
125+
platforms = [ANDROID],
126+
use_instrumentation_test = True,
127+
deps = [
128+
"//third-party/googletest:gtest_main",
129+
"//executorch/backends/vulkan:vulkan_graph_runtime",
130+
runtime.external_dep_location("libtorch"),
131+
],
142132
)
143133

144-
define_test_targets(
145-
"sdpa_test",
146-
extra_deps = [
134+
135+
runtime.cxx_binary(
136+
name = "sdpa_test_bin",
137+
srcs = [
138+
"sdpa_test.cpp",
139+
],
140+
compiler_flags = [
141+
"-Wno-unused-variable",
142+
],
143+
define_static_target = False,
144+
deps = [
145+
"//third-party/googletest:gtest_main",
146+
"//executorch/backends/vulkan:vulkan_graph_runtime",
147+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
148+
],
149+
)
150+
151+
runtime.cxx_test(
152+
name = "sdpa_test",
153+
srcs = [
154+
"sdpa_test.cpp",
155+
],
156+
contacts = ["[email protected]"],
157+
fbandroid_additional_loaded_sonames = [
158+
"torch-code-gen",
159+
"vulkan_graph_runtime",
160+
"vulkan_graph_runtime_shaderlib",
161+
],
162+
platforms = [ANDROID],
163+
use_instrumentation_test = True,
164+
deps = [
165+
"//third-party/googletest:gtest_main",
166+
"//executorch/backends/vulkan:vulkan_graph_runtime",
167+
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
168+
"//executorch/extension/tensor:tensor",
169+
runtime.external_dep_location("libtorch"),
170+
],
171+
)
172+
173+
runtime.cxx_binary(
174+
name = "linear_weight_int4_test_bin",
175+
srcs = [
176+
"linear_weight_int4_test.cpp",
177+
],
178+
compiler_flags = [
179+
"-Wno-unused-variable",
180+
],
181+
define_static_target = False,
182+
deps = [
183+
"//third-party/googletest:gtest_main",
184+
"//executorch/backends/vulkan:vulkan_graph_runtime",
185+
runtime.external_dep_location("libtorch"),
186+
],
187+
)
188+
189+
runtime.cxx_test(
190+
name = "linear_weight_int4_test",
191+
srcs = [
192+
"linear_weight_int4_test.cpp",
193+
],
194+
contacts = ["[email protected]"],
195+
fbandroid_additional_loaded_sonames = [
196+
"torch-code-gen",
197+
"vulkan_graph_runtime",
198+
"vulkan_graph_runtime_shaderlib",
199+
],
200+
platforms = [ANDROID],
201+
use_instrumentation_test = True,
202+
deps = [
203+
"//third-party/googletest:gtest_main",
204+
"//executorch/backends/vulkan:vulkan_graph_runtime",
147205
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
148206
"//executorch/extension/tensor:tensor",
149-
]
207+
runtime.external_dep_location("libtorch"),
208+
],
209+
)
210+
211+
runtime.cxx_binary(
212+
name = "rotary_embedding_test_bin",
213+
srcs = [
214+
"rotary_embedding_test.cpp",
215+
],
216+
compiler_flags = [
217+
"-Wno-unused-variable",
218+
],
219+
define_static_target = False,
220+
deps = [
221+
"//third-party/googletest:gtest_main",
222+
"//executorch/backends/vulkan:vulkan_graph_runtime",
223+
runtime.external_dep_location("libtorch"),
224+
],
225+
)
226+
227+
runtime.cxx_test(
228+
name = "rotary_embedding_test",
229+
srcs = [
230+
"rotary_embedding_test.cpp",
231+
],
232+
contacts = ["[email protected]"],
233+
fbandroid_additional_loaded_sonames = [
234+
"torch-code-gen",
235+
"vulkan_graph_runtime",
236+
"vulkan_graph_runtime_shaderlib",
237+
],
238+
platforms = [ANDROID],
239+
use_instrumentation_test = True,
240+
deps = [
241+
"//third-party/googletest:gtest_main",
242+
"//executorch/backends/vulkan:vulkan_graph_runtime",
243+
"//executorch/extension/tensor:tensor",
244+
runtime.external_dep_location("libtorch"),
245+
],
150246
)
151-
define_test_targets("linear_weight_int4_test")
152-
define_test_targets("rotary_embedding_test")

0 commit comments

Comments
 (0)