Skip to content

Commit cf92ed3

Browse files
committed
[ET-VK][9/n] clone node
Pull Request resolved: #3219 Introduce a clone node for copy operation. Also register `aten.clone` to this node. Important to note that during model export, possible to point the lvalue of `aten.clone` to the underlying shared object of the rvalue to achieve no-copy. ghstack-source-id: 223694720 Differential Revision: [D56441547](https://our.internmc.facebook.com/intern/diff/D56441547/)
1 parent 2166892 commit cf92ed3

File tree

6 files changed

+123
-4
lines changed

6 files changed

+123
-4
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
16+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
17+
18+
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
19+
ivec3 out_limits;
20+
};
21+
22+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
23+
24+
void main() {
25+
ivec3 pos = ivec3(gl_GlobalInvocationID);
26+
if (any(greaterThanEqual(pos, out_limits))) {
27+
return;
28+
}
29+
imageStore(image_out, pos, texelFetch(image_in, pos, 0));
30+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
clone:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: clone
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
16+
17+
namespace vkcompute {
18+
19+
void add_clone_node(
20+
ComputeGraph& graph,
21+
const ValueRef in,
22+
const ValueRef out) {
23+
vTensorPtr t_out = graph.get_tensor(out);
24+
25+
std::string kernel_name = "clone";
26+
add_dtype_suffix(kernel_name, *t_out);
27+
28+
api::utils::uvec3 global_size = t_out->extents();
29+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
30+
31+
graph.execute_nodes().emplace_back(new ExecuteNode(
32+
graph,
33+
VK_KERNEL_FROM_STR(kernel_name),
34+
global_size,
35+
local_size,
36+
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
37+
{t_out->texture_limits_ubo()}));
38+
}
39+
40+
void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
41+
// The vulkan delegate does not support changing memory format.
42+
return add_clone_node(graph, args[0], args[2]);
43+
}
44+
45+
// Clone node is not the most efficient implementation for the aten.clone
46+
// operation. A more efficient implementation can be achieved during vulkan
47+
// export with the use of shared object. This clone node is introduced to enable
48+
// a "copy" mechanism if there is no alternative (e.g. during direct
49+
// ComputeGraph manipulation, we need to make a copy of a Tensor).
50+
51+
REGISTER_OPERATORS {
52+
VK_REGISTER_OP(aten.clone.default, clone);
53+
}
54+
55+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,25 @@ def get_unsqueeze_inputs():
339339
return test_suite
340340

341341

342+
def get_clone_inputs():
343+
test_suite = VkTestSuite(
344+
[
345+
((S2, S1, S2, S1),),
346+
((S2, S1, S2),),
347+
((S2, S1),),
348+
((S2,),),
349+
((XS, S1, XS, S1),),
350+
((XS, S1, XS),),
351+
((S1, XS, S1),),
352+
((XS, S1),),
353+
((S1, XS),),
354+
((S1,),),
355+
((XS,),),
356+
]
357+
)
358+
return test_suite
359+
360+
342361
test_suites = {
343362
"aten.add.Tensor": get_binary_elementwise_inputs(),
344363
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -356,4 +375,5 @@ def get_unsqueeze_inputs():
356375
"aten.view_copy.default": get_view_inputs(),
357376
"aten.slice_copy.Tensor": get_slice_inputs(),
358377
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
378+
"aten.clone.default": get_clone_inputs(),
359379
}

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
OPT_DEVICE,
2222
OPT_INT64,
2323
OPT_LAYOUT,
24-
OPT_SCALARTYPE,
24+
OPT_MEMORY_FORMAT,
25+
OPT_SCALAR_TYPE,
2526
TestSuite,
2627
TestSuiteGen,
2728
THREE_TENSOR_TUPLE,
@@ -250,10 +251,11 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
250251
elif ref.src_cpp_type == DOUBLE:
251252
ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n"
252253
elif (
253-
ref.src_cpp_type == OPT_SCALARTYPE
254+
ref.src_cpp_type == OPT_SCALAR_TYPE
254255
or ref.src_cpp_type == OPT_LAYOUT
255256
or ref.src_cpp_type == OPT_DEVICE
256257
or ref.src_cpp_type == OPT_BOOL
258+
or ref.src_cpp_type == OPT_MEMORY_FORMAT
257259
):
258260
ret_str += "add_none(); \n"
259261
elif ref.src_cpp_type == TWO_TENSOR_TUPLE:

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
OPT_INT64 = "::std::optional<int64_t>"
2626
OPT_DEVICE = "::std::optional<at::Device>"
2727
OPT_LAYOUT = "::std::optional<at::Layout>"
28-
OPT_SCALARTYPE = "::std::optional<at::ScalarType>"
28+
OPT_MEMORY_FORMAT = "::std::optional<at::MemoryFormat>"
29+
OPT_SCALAR_TYPE = "::std::optional<at::ScalarType>"
2930
TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
3031
THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
3132

@@ -149,10 +150,11 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
149150
else:
150151
ret_str += f"{str(data)};"
151152
elif (
152-
cpp_type == OPT_SCALARTYPE
153+
cpp_type == OPT_SCALAR_TYPE
153154
or cpp_type == OPT_LAYOUT
154155
or cpp_type == OPT_DEVICE
155156
or cpp_type == OPT_BOOL
157+
or cpp_type == OPT_MEMORY_FORMAT
156158
):
157159
ret_str += "std::nullopt;"
158160
else:

0 commit comments

Comments
 (0)