Skip to content

Commit b27a82e

Browse files
committed
[ET-VK][9/n] clone node
Introduce a clone node for the copy operation, and register `aten.clone` to this node. Note that during model export it is possible to point the lvalue of `aten.clone` to the underlying shared object of the rvalue, which achieves a no-copy clone. Differential Revision: [D56441547](https://our.internmc.facebook.com/intern/diff/D56441547/) ghstack-source-id: 223471608 Pull Request resolved: #3219
1 parent e1de9eb commit b27a82e

File tree

6 files changed

+120
-0
lines changed

6 files changed

+120
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
16+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
17+
18+
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
19+
ivec3 out_limits;
20+
};
21+
22+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
23+
24+
/*
 * Copies a single texel from image_in to the same coordinate of image_out.
 * Each invocation handles the texel at its global invocation ID; invocations
 * that fall outside out_limits on any axis do nothing.
 */
void main() {
  ivec3 coord = ivec3(gl_GlobalInvocationID);

  // Only write when the coordinate is strictly below out_limits on every axis.
  if (all(lessThan(coord, out_limits))) {
    imageStore(image_out, coord, texelFetch(image_in, coord, 0));
  }
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
clone:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: clone
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
// Appends an ExecuteNode to `graph` that runs the "clone" shader (with a dtype
// suffix derived from the output tensor) to copy `in` into `out`.
//
// graph: compute graph that receives the new execute node.
// in:    value that is bound to the shader with READ access.
// out:   value that is bound to the shader with WRITE access; its extents
//        determine the global work group size and its texture limits UBO is
//        passed to the shader.
void add_clone_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef out) {
  vTensorPtr out_tensor = graph.get_tensor(out);

  // Select the shader variant matching the output tensor's dtype.
  std::string shader_name = "clone";
  add_dtype_suffix(shader_name, *out_tensor);

  // One invocation per element of the output extents.
  api::utils::uvec3 global_wg_size = out_tensor->extents();
  api::utils::uvec3 local_wg_size = adaptive_work_group_size(global_wg_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      global_wg_size,
      local_wg_size,
      // Resource order: output first (WRITE), then input (READ).
      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
      {out_tensor->texture_limits_ubo()}));
}
40+
41+
// Operator entry point registered for aten.clone.default. Forwards the first
// and third entries of `args` to add_clone_node as the input and output.
void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // The vulkan delegate does not support changing memory format, so args[1]
  // (presumably the optional memory_format argument) is ignored.
  const ValueRef src = args[0];
  const ValueRef dst = args[2];
  add_clone_node(graph, src, dst);
}
45+
46+
// The clone node is not the most efficient implementation of the aten.clone
47+
// operation. A more efficient implementation can be achieved during Vulkan
48+
// export by using a shared object. This clone node is introduced to enable
49+
// a "copy" mechanism when there is no alternative (e.g. during direct
50+
// ComputeGraph manipulation, when we need to make a copy of a Tensor).
51+
52+
// Associates the operator name `aten.clone.default` with the `clone` function
// defined in this file.
// NOTE(review): exact registration semantics come from VK_REGISTER_OP /
// REGISTER_OPERATORS in OperatorRegistry.h — confirm there.
REGISTER_OPERATORS {
53+
VK_REGISTER_OP(aten.clone.default, clone);
54+
}
55+
56+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,25 @@ def get_slice_inputs():
312312
return test_suite
313313

314314

315+
def get_clone_inputs():
    """Build the test suite for ``aten.clone.default``.

    Each test case is a 1-tuple holding the shape of the single input tensor;
    the shapes cover ranks 1 through 4 using the S1, S2 and XS size constants.
    """
    input_shapes = [
        (S2, S1, S2, S1),
        (S2, S1, S2),
        (S2, S1),
        (S2,),
        (XS, S1, XS, S1),
        (XS, S1, XS),
        (S1, XS, S1),
        (XS, S1),
        (S1, XS),
        (S1,),
        (XS,),
    ]
    # Wrap every shape in a 1-tuple: one positional argument per test case.
    return VkTestSuite([(shape,) for shape in input_shapes])
332+
333+
315334
test_suites = {
316335
"aten.add.Tensor": get_binary_elementwise_inputs(),
317336
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -328,4 +347,5 @@ def get_slice_inputs():
328347
"aten.permute_copy.default": get_permute_inputs(),
329348
"aten.view_copy.default": get_view_inputs(),
330349
"aten.slice_copy.Tensor": get_slice_inputs(),
350+
"aten.clone.default": get_clone_inputs(),
331351
}

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
OPT_DEVICE,
2222
OPT_INT64,
2323
OPT_LAYOUT,
24+
OPT_MEMORYFORMAT,
2425
OPT_SCALARTYPE,
2526
TestSuite,
2627
TestSuiteGen,
@@ -254,6 +255,7 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
254255
or ref.src_cpp_type == OPT_LAYOUT
255256
or ref.src_cpp_type == OPT_DEVICE
256257
or ref.src_cpp_type == OPT_BOOL
258+
or ref.src_cpp_type == OPT_MEMORYFORMAT
257259
):
258260
ret_str += "add_none(); \n"
259261
elif ref.src_cpp_type == TWO_TENSOR_TUPLE:

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
OPT_INT64 = "::std::optional<int64_t>"
2626
OPT_DEVICE = "::std::optional<at::Device>"
2727
OPT_LAYOUT = "::std::optional<at::Layout>"
28+
OPT_MEMORYFORMAT = "::std::optional<at::MemoryFormat>"
2829
OPT_SCALARTYPE = "::std::optional<at::ScalarType>"
2930
TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
3031
THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
@@ -153,6 +154,7 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
153154
or cpp_type == OPT_LAYOUT
154155
or cpp_type == OPT_DEVICE
155156
or cpp_type == OPT_BOOL
157+
or cpp_type == OPT_MEMORYFORMAT
156158
):
157159
ret_str += "std::nullopt;"
158160
else:

0 commit comments

Comments
 (0)