Skip to content

Commit 762a21e

Browse files
committed
Update on "[ET-VK] Fix OSS build + separate test build into its own CMakeLists.txt"
## Context As title. In the next diff, the vulkan test binary will be added to CI. Differential Revision: [D57747739](https://our.internmc.facebook.com/intern/diff/D57747739) [ghstack-poisoned]
2 parents 10cb72f + 3c4e875 commit 762a21e

34 files changed

+266
-44
lines changed

.ci/scripts/setup-macos.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,18 @@ install_buck() {
2828
fi
2929

3030
pushd .ci/docker
31-
3231
# TODO(huydo): This is a one-off copy of buck2 2024-05-15 to unblock Jon and
3332
# re-enable ShipIt. It’s not ideal that upgrading buck2 will require a manual
3433
# update the cached binary on S3 bucket too. Let me figure out if there is a
3534
# way to correctly implement the previous setup of installing a new version of
3635
# buck2 only when it’s needed. AFAIK, the complicated part was that buck2
3736
# --version doesn't say anything w.r.t its release version, i.e. 2024-05-15.
3837
# See D53878006 for more details.
39-
BUCK2=buck2-aarch64-apple-darwin.zst
38+
#
39+
# If you need to upgrade buck2 version on S3, please reach out to Dev Infra
40+
# team for help.
41+
BUCK2_VERSION=$(cat ci_commit_pins/buck2.txt)
42+
BUCK2=buck2-aarch64-apple-darwin-${BUCK2_VERSION}.zst
4043
curl -s "https://ossci-macos.s3.amazonaws.com/${BUCK2}" -o "${BUCK2}"
4144

4245
zstd -d "${BUCK2}" -o buck2

backends/vulkan/main.cpp

Lines changed: 0 additions & 7 deletions
This file was deleted.

backends/vulkan/partitioner/supported_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,15 @@ def __contains__(self, op):
5252
UNARY_OPS = [
5353
exir_ops.edge.aten.abs.default,
5454
exir_ops.edge.aten.clamp.default,
55+
exir_ops.edge.aten.cos.default,
5556
exir_ops.edge.aten.exp.default,
5657
exir_ops.edge.aten.gelu.default,
5758
exir_ops.edge.aten.hardshrink.default,
5859
exir_ops.edge.aten.hardtanh.default,
60+
exir_ops.edge.aten.neg.default,
5961
exir_ops.edge.aten.relu.default,
6062
exir_ops.edge.aten.sigmoid.default,
63+
exir_ops.edge.aten.sin.default,
6164
exir_ops.edge.aten.sqrt.default,
6265
exir_ops.edge.aten.tanh.default,
6366
]
@@ -84,6 +87,7 @@ def __contains__(self, op):
8487
]
8588

8689
NORMALIZATION_OPS = [
90+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
8791
exir_ops.edge.aten.native_layer_norm.default,
8892
]
8993

backends/vulkan/runtime/api/QueryPool.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818
#include <executorch/backends/vulkan/runtime/api/Command.h>
1919
#include <executorch/backends/vulkan/runtime/api/Pipeline.h>
2020

21+
#ifndef VULKAN_QUERY_POOL_SIZE
22+
#define VULKAN_QUERY_POOL_SIZE 4096u
23+
#endif
24+
2125
namespace vkcompute {
2226
namespace api {
2327

2428
struct QueryPoolConfig final {
25-
uint32_t max_query_count;
26-
uint32_t initial_reserve_size;
29+
uint32_t max_query_count = VULKAN_QUERY_POOL_SIZE;
30+
uint32_t initial_reserve_size = 256u;
2731
};
2832

2933
struct ShaderDuration final {

backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,18 @@ unary_op:
1818
- NAME: clamp_int
1919
OPERATOR: clamp(X, A, B)
2020
DTYPE: int
21+
- NAME: cos
22+
OPERATOR: cos(X)
2123
- NAME: exp
2224
OPERATOR: exp(X)
2325
- NAME: gelu
2426
OPERATOR: 0.5 * X * (1 + tanh(sqrt(2 / 3.141593) * (X + 0.044715 * X * X * X)))
27+
- NAME: neg
28+
OPERATOR: -X
2529
- NAME: sigmoid
2630
OPERATOR: 1 / (1 + exp(-1 * X))
31+
- NAME: sin
32+
OPERATOR: sin(X)
2733
- NAME: sqrt
2834
OPERATOR: sqrt(X)
2935
- NAME: tanh
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#include "broadcasting_utils.h"
12+
#include "indexing_utils.h"
13+
14+
#define PRECISION ${PRECISION}
15+
16+
#define VEC4_T ${texel_type(DTYPE)}
17+
18+
layout(std430) buffer;
19+
20+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21+
22+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
23+
24+
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
25+
ivec3 out_limits;
26+
};
27+
28+
layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
29+
ivec4 sizes;
30+
};
31+
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
33+
34+
void main() {
35+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
36+
37+
if (any(greaterThanEqual(pos, out_limits))) {
38+
return;
39+
}
40+
41+
VEC4_T in_texel = texelFetch(image_in, pos, 0);
42+
imageStore(image_out, pos, in_texel);
43+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
upsample:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
PACKING: C_packed
12+
generate_variant_forall:
13+
DTYPE:
14+
- VALUE: half
15+
- VALUE: float
16+
shader_variants:
17+
- NAME: upsample

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,11 @@ void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
122122
}
123123

124124
DEFINE_ACTIVATION_FN(abs);
125+
DEFINE_ACTIVATION_FN(cos);
125126
DEFINE_ACTIVATION_FN(exp);
127+
DEFINE_ACTIVATION_FN(neg);
126128
DEFINE_ACTIVATION_FN(sigmoid);
129+
DEFINE_ACTIVATION_FN(sin);
127130
DEFINE_ACTIVATION_FN(sqrt);
128131
DEFINE_ACTIVATION_FN(tanh);
129132
DEFINE_CLAMP_FN(clamp);
@@ -134,11 +137,14 @@ DEFINE_HARDSHRINK_FN(hardshrink);
134137
REGISTER_OPERATORS {
135138
VK_REGISTER_OP(aten.abs.default, abs);
136139
VK_REGISTER_OP(aten.clamp.default, clamp);
140+
VK_REGISTER_OP(aten.cos.default, cos);
137141
VK_REGISTER_OP(aten.exp.default, exp);
138142
VK_REGISTER_OP(aten.gelu.default, gelu);
139143
VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
144+
VK_REGISTER_OP(aten.neg.default, neg);
140145
VK_REGISTER_OP(aten.relu.default, relu);
141146
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
147+
VK_REGISTER_OP(aten.sin.default, sin);
142148
VK_REGISTER_OP(aten.sqrt.default, sqrt);
143149
VK_REGISTER_OP(aten.tanh.default, tanh);
144150
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
void resize_upsample_node(
21+
ComputeGraph* graph,
22+
const std::vector<ArgGroup>& args,
23+
const std::vector<ValueRef>& extra_args) {
24+
(void)graph;
25+
(void)args;
26+
(void)extra_args;
27+
}
28+
29+
void add_upsample_node(
30+
ComputeGraph& graph,
31+
const ValueRef in,
32+
const ValueRef out) {
33+
ValueRef arg = prepack_if_tensor_ref(graph, in);
34+
35+
vTensorPtr t_out = graph.get_tensor(out);
36+
api::utils::uvec3 global_size = t_out->image_extents();
37+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
38+
39+
std::string kernel_name("upsample");
40+
kernel_name.reserve(kShaderNameReserve);
41+
42+
add_dtype_suffix(kernel_name, *t_out);
43+
44+
graph.execute_nodes().emplace_back(new ExecuteNode(
45+
graph,
46+
VK_KERNEL_FROM_STR(kernel_name),
47+
global_size,
48+
local_size,
49+
// Inputs and Outputs
50+
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
51+
// Shader params buffers
52+
{t_out->texture_limits_ubo(), graph.create_params_buffer(0.5)},
53+
// Specialization Constants
54+
{},
55+
// Resizing Logic
56+
resize_upsample_node));
57+
}
58+
59+
void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
60+
return add_upsample_node(graph, args[0], args[3]);
61+
}
62+
63+
REGISTER_OPERATORS {
64+
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample);
65+
}
66+
67+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ def get_native_layer_norm_inputs():
241241
return test_suite
242242

243243

244+
def get_upsample_inputs():
245+
test_suite = VkTestSuite(
246+
[
247+
# TODO(dixu): implement the basic upsample logic to have a meaningful test
248+
((2, 2, 2, 2), None, [1, 1]),
249+
]
250+
)
251+
return test_suite
252+
253+
244254
def get_full_inputs():
245255
test_suite = VkTestSuite(
246256
[
@@ -672,6 +682,8 @@ def get_unary_ops_inputs():
672682
]
673683
)
674684
test_suite.storage_types = ["api::kTexture3D", "api::kBuffer"]
685+
test_suite.atol = "1e-4"
686+
test_suite.rtol = "1e-4"
675687
return test_suite
676688

677689

@@ -796,4 +808,8 @@ def get_gelu_inputs():
796808
"aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
797809
"aten.gelu.default": get_gelu_inputs(),
798810
"aten.hardshrink.default": get_unary_ops_inputs(),
811+
"aten.upsample_nearest2d.vec": get_upsample_inputs(),
812+
"aten.sin.default": get_unary_ops_inputs(),
813+
"aten.neg.default": get_unary_ops_inputs(),
814+
"aten.cos.default": get_unary_ops_inputs(),
799815
}

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
CppTestFileGen,
1818
DOUBLE,
1919
INT,
20+
OPT_AT_DOUBLE_ARRAY_REF,
21+
OPT_AT_INT_ARRAY_REF,
2022
OPT_AT_TENSOR,
2123
OPT_BOOL,
2224
OPT_DEVICE,
@@ -289,6 +291,16 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
289291
ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
290292
ret_str += f"({ref.src_cpp_name}.value());\n"
291293
return ret_str
294+
elif (
295+
ref.src_cpp_type == OPT_AT_DOUBLE_ARRAY_REF
296+
or ref.src_cpp_type == OPT_AT_INT_ARRAY_REF
297+
):
298+
ret_str = f"{cpp_type} {ref.name} = "
299+
ret_str += f"!{ref.src_cpp_name}.has_value() ? "
300+
ret_str += f"{self.graph}{self.dot}add_none() : "
301+
ret_str += f"{self.graph}{self.dot}add_scalar_list"
302+
ret_str += f"({ref.src_cpp_name}->vec());\n"
303+
return ret_str
292304
elif ref.src_cpp_type == AT_TENSOR_LIST:
293305
assert ref.is_in, "AT_TENSOR_LIST must be an input"
294306
# This logic is a bit convoluted. We need to create a IOValueRef for
@@ -588,8 +600,8 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
588600
protected:
589601
ComputeGraph* graph;
590602
at::ScalarType test_dtype = at::kFloat;
591-
float rtol = 1e-5;
592-
float atol = 1e-5;
603+
float rtol = {rtol};
604+
float atol = {atol};
593605
594606
void SetUp() override {{
595607
GraphConfig config;
@@ -639,6 +651,8 @@ def generate_fixture_cpp(self) -> str:
639651
op_name=self.op_name,
640652
check_fn=check_fn,
641653
prepacked_check_fn=prepacked_check_fn,
654+
rtol=self.suite_def.rtol,
655+
atol=self.suite_def.atol,
642656
)
643657

644658
def gen_parameterization(self) -> str:

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
BOOL = "bool"
2323
DOUBLE = "double"
2424
INT = "int64_t"
25+
OPT_AT_DOUBLE_ARRAY_REF = "::std::optional<at::ArrayRef<double>>"
26+
OPT_AT_INT_ARRAY_REF = "at::OptionalIntArrayRef"
2527
OPT_AT_TENSOR = "::std::optional<at::Tensor>"
2628
OPT_BOOL = "::std::optional<bool>"
2729
OPT_INT64 = "::std::optional<int64_t>"
@@ -45,6 +47,8 @@ def __init__(self, input_cases: List[Any]):
4547
self.prepacked_args: List[str] = []
4648
self.requires_prepack: bool = False
4749
self.dtypes: List[str] = ["at::kFloat", "at::kHalf"]
50+
self.atol: str = "1e-5"
51+
self.rtol: str = "1e-5"
4852

4953
def supports_prepack(self):
5054
return len(self.prepacked_args) > 0
@@ -142,6 +146,10 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
142146

143147
if cpp_type == AT_INT_ARRAY_REF:
144148
ret_str = f"std::vector<int64_t> {arg.name} = "
149+
elif (
150+
cpp_type == OPT_AT_DOUBLE_ARRAY_REF or cpp_type == OPT_AT_INT_ARRAY_REF
151+
) and str(data) != "None":
152+
ret_str = f"std::vector<double> {arg.name} = "
145153
else:
146154
ret_str = f"{cpp_type} {arg.name} = "
147155

@@ -156,6 +164,11 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
156164
ret_str += f"{data};"
157165
elif cpp_type == AT_INT_ARRAY_REF:
158166
ret_str += f"{init_list_str(data)};"
167+
elif cpp_type == OPT_AT_DOUBLE_ARRAY_REF or cpp_type == OPT_AT_INT_ARRAY_REF:
168+
if str(data) == "None":
169+
ret_str += "std::nullopt;"
170+
else:
171+
ret_str += f"{init_list_str(data)};"
159172
elif cpp_type == BOOL:
160173
ret_str += f"{str(data).lower()};"
161174
elif cpp_type == INT:

0 commit comments

Comments
 (0)