Skip to content

Commit c28ca18

Browse files
Di Xu (AR)facebook-github-bot
authored andcommitted
Integrate a placeholder upsample_nearest2d.vec to Vulkan codegen operator tests (#3711)
Summary: Pull Request resolved: #3711 Integrate a placeholder upsample_nearest2d.vec to Vulkan codegen operator tests - Right now just implement the codegen to support opt_int_array and opt_double_array. - Using a scaling factor = 1 and just passing the input to pass the unit tests bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: jorgep31415 Differential Revision: D57643709 fbshipit-source-id: 9d7f177890c035a5c7b44bf0a9a7e439d6a16da8
1 parent 1dd4935 commit c28ca18

File tree

6 files changed

+161
-0
lines changed

6 files changed

+161
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#include "broadcasting_utils.h"
12+
#include "indexing_utils.h"
13+
14+
#define PRECISION ${PRECISION}
15+
16+
#define VEC4_T ${texel_type(DTYPE)}
17+
18+
layout(std430) buffer;
19+
20+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21+
22+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
23+
24+
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
25+
ivec3 out_limits;
26+
};
27+
28+
layout(set = 0, binding = 3) uniform PRECISION restrict Sizes {
29+
ivec4 sizes;
30+
};
31+
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
33+
34+
void main() {
35+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
36+
37+
if (any(greaterThanEqual(pos, out_limits))) {
38+
return;
39+
}
40+
41+
VEC4_T in_texel = texelFetch(image_in, pos, 0);
42+
imageStore(image_out, pos, in_texel);
43+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
upsample:
8+
parameter_names_with_default_values:
9+
NDIM: 3
10+
DTYPE: float
11+
PACKING: C_packed
12+
generate_variant_forall:
13+
DTYPE:
14+
- VALUE: half
15+
- VALUE: float
16+
shader_variants:
17+
- NAME: upsample
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
void resize_upsample_node(
21+
ComputeGraph* graph,
22+
const std::vector<ArgGroup>& args,
23+
const std::vector<ValueRef>& extra_args) {
24+
(void)graph;
25+
(void)args;
26+
(void)extra_args;
27+
}
28+
29+
void add_upsample_node(
30+
ComputeGraph& graph,
31+
const ValueRef in,
32+
const ValueRef out) {
33+
ValueRef arg = prepack_if_tensor_ref(graph, in);
34+
35+
vTensorPtr t_out = graph.get_tensor(out);
36+
api::utils::uvec3 global_size = t_out->image_extents();
37+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
38+
39+
std::string kernel_name("upsample");
40+
kernel_name.reserve(kShaderNameReserve);
41+
42+
add_dtype_suffix(kernel_name, *t_out);
43+
44+
graph.execute_nodes().emplace_back(new ExecuteNode(
45+
graph,
46+
VK_KERNEL_FROM_STR(kernel_name),
47+
global_size,
48+
local_size,
49+
// Inputs and Outputs
50+
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
51+
// Shader params buffers
52+
{t_out->texture_limits_ubo(), graph.create_params_buffer(0.5)},
53+
// Specialization Constants
54+
{},
55+
// Resizing Logic
56+
resize_upsample_node));
57+
}
58+
59+
void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
60+
return add_upsample_node(graph, args[0], args[3]);
61+
}
62+
63+
REGISTER_OPERATORS {
64+
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample);
65+
}
66+
67+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ def get_native_layer_norm_inputs():
241241
return test_suite
242242

243243

244+
def get_upsample_inputs():
245+
test_suite = VkTestSuite(
246+
[
247+
# TODO(dixu): implement the basic upsample logic to have a meaningful test
248+
((2, 2, 2, 2), None, [1, 1]),
249+
]
250+
)
251+
return test_suite
252+
253+
244254
def get_full_inputs():
245255
test_suite = VkTestSuite(
246256
[
@@ -796,4 +806,5 @@ def get_gelu_inputs():
796806
"aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
797807
"aten.gelu.default": get_gelu_inputs(),
798808
"aten.hardshrink.default": get_unary_ops_inputs(),
809+
"aten.upsample_nearest2d.vec": get_upsample_inputs(),
799810
}

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
CppTestFileGen,
1818
DOUBLE,
1919
INT,
20+
OPT_AT_DOUBLE_ARRAY_REF,
21+
OPT_AT_INT_ARRAY_REF,
2022
OPT_AT_TENSOR,
2123
OPT_BOOL,
2224
OPT_DEVICE,
@@ -289,6 +291,16 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
289291
ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
290292
ret_str += f"({ref.src_cpp_name}.value());\n"
291293
return ret_str
294+
elif (
295+
ref.src_cpp_type == OPT_AT_DOUBLE_ARRAY_REF
296+
or ref.src_cpp_type == OPT_AT_INT_ARRAY_REF
297+
):
298+
ret_str = f"{cpp_type} {ref.name} = "
299+
ret_str += f"!{ref.src_cpp_name}.has_value() ? "
300+
ret_str += f"{self.graph}{self.dot}add_none() : "
301+
ret_str += f"{self.graph}{self.dot}add_scalar_list"
302+
ret_str += f"({ref.src_cpp_name}->vec());\n"
303+
return ret_str
292304
elif ref.src_cpp_type == AT_TENSOR_LIST:
293305
assert ref.is_in, "AT_TENSOR_LIST must be an input"
294306
# This logic is a bit convoluted. We need to create a IOValueRef for

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
BOOL = "bool"
2323
DOUBLE = "double"
2424
INT = "int64_t"
25+
OPT_AT_DOUBLE_ARRAY_REF = "::std::optional<at::ArrayRef<double>>"
26+
OPT_AT_INT_ARRAY_REF = "at::OptionalIntArrayRef"
2527
OPT_AT_TENSOR = "::std::optional<at::Tensor>"
2628
OPT_BOOL = "::std::optional<bool>"
2729
OPT_INT64 = "::std::optional<int64_t>"
@@ -142,6 +144,10 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
142144

143145
if cpp_type == AT_INT_ARRAY_REF:
144146
ret_str = f"std::vector<int64_t> {arg.name} = "
147+
elif (
148+
cpp_type == OPT_AT_DOUBLE_ARRAY_REF or cpp_type == OPT_AT_INT_ARRAY_REF
149+
) and str(data) != "None":
150+
ret_str = f"std::vector<double> {arg.name} = "
145151
else:
146152
ret_str = f"{cpp_type} {arg.name} = "
147153

@@ -156,6 +162,11 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
156162
ret_str += f"{data};"
157163
elif cpp_type == AT_INT_ARRAY_REF:
158164
ret_str += f"{init_list_str(data)};"
165+
elif cpp_type == OPT_AT_DOUBLE_ARRAY_REF or cpp_type == OPT_AT_INT_ARRAY_REF:
166+
if str(data) == "None":
167+
ret_str += "std::nullopt;"
168+
else:
169+
ret_str += f"{init_list_str(data)};"
159170
elif cpp_type == BOOL:
160171
ret_str += f"{str(data).lower()};"
161172
elif cpp_type == INT:

0 commit comments

Comments
 (0)