@@ -3,44 +3,6 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
3
3
load ("@fbsource//xplat/caffe2:pt_ops.bzl" , "pt_operator_library" )
4
4
load ("@fbsource//xplat/executorch/build:runtime_wrapper.bzl" , "runtime" )
5
5
6
- def define_test_targets (test_name , extra_deps = [], src_file = None , is_fbcode = False ):
7
- deps_list = [
8
- "//third-party/googletest:gtest_main" ,
9
- "//executorch/backends/vulkan:vulkan_graph_runtime" ,
10
- runtime .external_dep_location ("libtorch" ),
11
- ] + extra_deps
12
-
13
- src_file_str = src_file if src_file else "{}.cpp" .format (test_name )
14
-
15
- runtime .cxx_binary (
16
- name = "{}_bin" .format (test_name ),
17
- srcs = [
18
- src_file_str ,
19
- ],
20
- compiler_flags = [
21
- "-Wno-unused-variable" ,
22
- ],
23
- define_static_target = False ,
24
- deps = deps_list ,
25
- )
26
-
27
- runtime .cxx_test (
28
- name = test_name ,
29
- srcs = [
30
- src_file_str ,
31
- ],
32
-
33
- fbandroid_additional_loaded_sonames = [
34
- "torch-code-gen" ,
35
- "vulkan_graph_runtime" ,
36
- "vulkan_graph_runtime_shaderlib" ,
37
- ],
38
- platforms = [ANDROID ],
39
- use_instrumentation_test = True ,
40
- deps = deps_list ,
41
- )
42
-
43
-
44
6
def define_common_targets (is_fbcode = False ):
45
7
if is_fbcode :
46
8
return
@@ -120,6 +82,19 @@ def define_common_targets(is_fbcode = False):
120
82
default_outs = ["." ],
121
83
)
122
84
85
+ runtime .cxx_binary (
86
+ name = "compute_graph_op_tests_bin" ,
87
+ srcs = [
88
+ ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
89
+ ],
90
+ define_static_target = False ,
91
+ deps = [
92
+ "//third-party/googletest:gtest_main" ,
93
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
94
+ runtime .external_dep_location ("libtorch" ),
95
+ ],
96
+ )
97
+
123
98
runtime .cxx_binary (
124
99
name = "compute_graph_op_benchmarks_bin" ,
125
100
srcs = [
@@ -136,17 +111,136 @@ def define_common_targets(is_fbcode = False):
136
111
],
137
112
)
138
113
139
- define_test_targets (
140
- "compute_graph_op_tests" ,
141
- src_file = ":generated_op_correctness_tests_cpp[op_tests.cpp]"
114
+ runtime .cxx_test (
115
+ name = "compute_graph_op_tests" ,
116
+ srcs = [
117
+ ":generated_op_correctness_tests_cpp[op_tests.cpp]" ,
118
+ ],
119
+
120
+ fbandroid_additional_loaded_sonames = [
121
+ "torch-code-gen" ,
122
+ "vulkan_graph_runtime" ,
123
+ "vulkan_graph_runtime_shaderlib" ,
124
+ ],
125
+ platforms = [ANDROID ],
126
+ use_instrumentation_test = True ,
127
+ deps = [
128
+ "//third-party/googletest:gtest_main" ,
129
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
130
+ runtime .external_dep_location ("libtorch" ),
131
+ ],
142
132
)
143
133
144
- define_test_targets (
145
- "sdpa_test" ,
146
- extra_deps = [
134
+
135
+ runtime .cxx_binary (
136
+ name = "sdpa_test_bin" ,
137
+ srcs = [
138
+ "sdpa_test.cpp" ,
139
+ ],
140
+ compiler_flags = [
141
+ "-Wno-unused-variable" ,
142
+ ],
143
+ define_static_target = False ,
144
+ deps = [
145
+ "//third-party/googletest:gtest_main" ,
146
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
147
+ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
148
+ ],
149
+ )
150
+
151
+ runtime .cxx_test (
152
+ name = "sdpa_test" ,
153
+ srcs = [
154
+ "sdpa_test.cpp" ,
155
+ ],
156
+
157
+ fbandroid_additional_loaded_sonames = [
158
+ "torch-code-gen" ,
159
+ "vulkan_graph_runtime" ,
160
+ "vulkan_graph_runtime_shaderlib" ,
161
+ ],
162
+ platforms = [ANDROID ],
163
+ use_instrumentation_test = True ,
164
+ deps = [
165
+ "//third-party/googletest:gtest_main" ,
166
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
167
+ "//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
168
+ "//executorch/extension/tensor:tensor" ,
169
+ runtime .external_dep_location ("libtorch" ),
170
+ ],
171
+ )
172
+
173
+ runtime .cxx_binary (
174
+ name = "linear_weight_int4_test_bin" ,
175
+ srcs = [
176
+ "linear_weight_int4_test.cpp" ,
177
+ ],
178
+ compiler_flags = [
179
+ "-Wno-unused-variable" ,
180
+ ],
181
+ define_static_target = False ,
182
+ deps = [
183
+ "//third-party/googletest:gtest_main" ,
184
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
185
+ runtime .external_dep_location ("libtorch" ),
186
+ ],
187
+ )
188
+
189
+ runtime .cxx_test (
190
+ name = "linear_weight_int4_test" ,
191
+ srcs = [
192
+ "linear_weight_int4_test.cpp" ,
193
+ ],
194
+
195
+ fbandroid_additional_loaded_sonames = [
196
+ "torch-code-gen" ,
197
+ "vulkan_graph_runtime" ,
198
+ "vulkan_graph_runtime_shaderlib" ,
199
+ ],
200
+ platforms = [ANDROID ],
201
+ use_instrumentation_test = True ,
202
+ deps = [
203
+ "//third-party/googletest:gtest_main" ,
204
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
147
205
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib" ,
148
206
"//executorch/extension/tensor:tensor" ,
149
- ]
207
+ runtime .external_dep_location ("libtorch" ),
208
+ ],
209
+ )
210
+
211
+ runtime .cxx_binary (
212
+ name = "rotary_embedding_test_bin" ,
213
+ srcs = [
214
+ "rotary_embedding_test.cpp" ,
215
+ ],
216
+ compiler_flags = [
217
+ "-Wno-unused-variable" ,
218
+ ],
219
+ define_static_target = False ,
220
+ deps = [
221
+ "//third-party/googletest:gtest_main" ,
222
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
223
+ runtime .external_dep_location ("libtorch" ),
224
+ ],
225
+ )
226
+
227
+ runtime .cxx_test (
228
+ name = "rotary_embedding_test" ,
229
+ srcs = [
230
+ "rotary_embedding_test.cpp" ,
231
+ ],
232
+
233
+ fbandroid_additional_loaded_sonames = [
234
+ "torch-code-gen" ,
235
+ "vulkan_graph_runtime" ,
236
+ "vulkan_graph_runtime_shaderlib" ,
237
+ ],
238
+ platforms = [ANDROID ],
239
+ use_instrumentation_test = True ,
240
+ deps = [
241
+ "//third-party/googletest:gtest_main" ,
242
+ "//executorch/backends/vulkan:vulkan_graph_runtime" ,
243
+ "//executorch/extension/tensor:tensor" ,
244
+ runtime .external_dep_location ("libtorch" ),
245
+ ],
150
246
)
151
- define_test_targets ("linear_weight_int4_test" )
152
- define_test_targets ("rotary_embedding_test" )
0 commit comments