
Commit 0bbca93
[ET-VK][14/n] aten.t, aten._to_copy, aten.contiguous
Adds 3 more trivial operators. For `_to_copy` and `contiguous`, the Vulkan memory layout differs from the CPU layout, so we ignore their layout and memory-format arguments. The one exception is changing `dtype`; we will add that feature when needed.

Differential Revision: [D56666219](https://our.internmc.facebook.com/intern/diff/D56666219/)

ghstack-source-id: 224258877
Pull Request resolved: #3390
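For reference, here is a minimal eager-mode module that exercises all three ops; this is an illustrative sketch, not part of the commit, and the exact traced ops can vary with export settings:

    import torch

    class TrivialOps(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x.t()                    # can trace to aten.t.default
            y = y.contiguous()           # can trace to aten.contiguous.default
            return y.to(torch.float32)   # can trace to aten._to_copy.default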
1 parent 23e04e2

File tree

5 files changed: 94 additions, 1 deletion

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 16 additions & 0 deletions
@@ -42,6 +42,20 @@ void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_clone_node(graph, args[0], args[2]);
 }
 
+void contiguous(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  // The vulkan delegate does not support changing memory format.
+  return add_clone_node(graph, args[0], args[2]);
+}
+
+void _to_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  // All arguments are ignored for the time being.
+  // _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None,
+  //     Device? device=None, bool? pin_memory=None, bool non_blocking=False,
+  //     MemoryFormat? memory_format=None) -> Tensor
+
+  return add_clone_node(graph, args[0], args[7]);
+}
+
 // Clone node is not the most efficient implementation for the aten.clone
 // operation. A more efficient implementation can be achieved during vulkan
 // export with the use of shared object. This clone node is introduced to enable

@@ -50,6 +64,8 @@ void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.clone.default, clone);
+  VK_REGISTER_OP(aten.contiguous.default, contiguous);
+  VK_REGISTER_OP(aten._to_copy.default, _to_copy);
 }
 
 } // namespace vkcompute
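Since every schema argument except the input is ignored, both new ops reduce to the existing clone node. Per the schema comment above, args[0] is `self` and the final entry is the output value appended after the schema arguments (index 2 for clone/contiguous, index 7 for _to_copy). An eager-mode sketch of the resulting semantics (my reading, not code from this commit):

    import torch

    def vulkan_contiguous_ref(x: torch.Tensor) -> torch.Tensor:
        # memory_format is ignored; the delegate just copies.
        return x.clone()

    def vulkan_to_copy_ref(x: torch.Tensor) -> torch.Tensor:
        # dtype/layout/device/pin_memory/non_blocking/memory_format are
        # all ignored for now; dtype conversion is planned when needed.
        return x.clone()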
Lines changed: 46 additions & 0 deletions (new file)
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+
+namespace vkcompute {
+
+void add_t_default_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+
+  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
+
+  // TODO: Verify 0-dim tensor
+  VK_CHECK_COND(
+      (1 <= t_in->dim()) && (t_in->dim() <= 2),
+      "aten.t tensor must be 1d or 2d");
+
+  std::vector<int64_t> permute_dims;
+  if (t_in->dim() == 1) {
+    permute_dims.emplace_back(0);
+  } else {
+    permute_dims.emplace_back(1);
+    permute_dims.emplace_back(0);
+  }
+
+  add_permute_node(graph, in, permute_dims, out);
+}
+
+void t_default(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  add_t_default_node(graph, args[0], args[1]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.t.default, t_default);
+}
+
+} // namespace vkcompute
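The mapping to permute can be sanity-checked in eager mode; a small sketch (mine, not from the commit):

    import torch

    def t_via_permute(x: torch.Tensor) -> torch.Tensor:
        # Mirrors add_t_default_node: identity permutation for 1-d,
        # swap of the two dims for 2-d.
        assert 1 <= x.dim() <= 2, "aten.t tensor must be 1d or 2d"
        perm = [0] if x.dim() == 1 else [1, 0]
        return x.permute(perm)

    x = torch.arange(6.0).reshape(2, 3)
    assert torch.equal(t_via_permute(x), x.t())
    v = torch.arange(4.0)
    assert torch.equal(t_via_permute(v), v.t())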

backends/vulkan/test/op_tests/cases.py

Lines changed: 22 additions & 0 deletions
@@ -555,6 +555,25 @@ def get_split_tensor_inputs():
     return test_suite
 
 
+def get_t_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((1, S1),),
+            ((S1, 1),),
+            ((S2, S2),),
+            ((S2, S1),),
+            ((S1, S2),),
+            ((S1,),),
+            ((1,),),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),

@@ -573,8 +592,11 @@ def get_split_tensor_inputs():
     "aten.slice_copy.Tensor": get_slice_inputs(),
     "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
     "aten.clone.default": get_clone_inputs(),
+    "aten.contiguous.default": get_clone_inputs(),
+    "aten._to_copy.default": get_clone_inputs(),
     "aten.repeat.default": get_repeat_inputs(),
     "aten.cat.default": get_cat_inputs(),
     "aten.split_with_sizes.default": get_split_with_sizes_inputs(),
     "aten.split.Tensor": get_split_tensor_inputs(),
+    "aten.t.default": get_t_inputs(),
 }
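Each test case is a tuple of operator arguments; aten.t takes a single tensor, so an entry like ((S1, S2),) describes one input tensor of that shape. A sketch of how such a case expands (the S1/S2 values here are hypothetical; cases.py defines its own size constants):

    import torch

    S1, S2 = 7, 11        # hypothetical sizes
    case = ((S1, S2),)    # one test case from get_t_inputs()
    (shape,) = case       # a single tensor argument of shape (7, 11)
    # "make_seq_tensor"-style data: sequential values, easy to eyeball.
    x = torch.arange(S1 * S2, dtype=torch.float32).reshape(shape)
    expected = x.t()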

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
     CppTestFileGen,
     DOUBLE,
     INT,
+    MEMORY_FORMAT,
     OPT_AT_TENSOR,
     OPT_BOOL,
     OPT_DEVICE,

@@ -231,7 +232,7 @@ def create_aten_method_call(self) -> str:
         # at::_ops::{name}::call(*), and ATEN_FN is a handy macro.
         cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f)
         exprs = translate_args(self.f_sig, cpp_sig)
-        func_call = f"ATEN_FN({self.f_sig.name()})({exprs});"
+        func_call = f"ATEN_FN({self.f_sig.func.name})({exprs});"
         return func_call
 
     def create_out_src(self) -> str:

@@ -342,6 +343,7 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             or ref.src_cpp_type == OPT_DEVICE
             or ref.src_cpp_type == OPT_BOOL
             or ref.src_cpp_type == OPT_MEMORY_FORMAT
+            or ref.src_cpp_type == MEMORY_FORMAT
         ):
             ret_str += "add_none(); \n"
         elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
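The one-line fix swaps the C++-signature name for the schema-level operator name when formatting the ATEN_FN call; presumably f_sig.name() can diverge from the base operator name that the macro expects. A toy illustration of the string being built (operator and argument names assumed for illustration):

    # Illustrative only: the generated line for aten._to_copy.default.
    op_base_name = "_to_copy"  # what f_sig.func.name would yield
    exprs = "self, dtype, layout, device, pin_memory, non_blocking, memory_format"
    func_call = f"ATEN_FN({op_base_name})({exprs});"
    # -> ATEN_FN(_to_copy)(self, dtype, layout, device, pin_memory, non_blocking, memory_format);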

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 7 additions & 0 deletions
@@ -22,6 +22,7 @@
 BOOL = "bool"
 DOUBLE = "double"
 INT = "int64_t"
+MEMORY_FORMAT = "at::MemoryFormat"
 OPT_AT_TENSOR = "::std::optional<at::Tensor>"
 OPT_BOOL = "::std::optional<bool>"
 OPT_INT64 = "::std::optional<int64_t>"

@@ -174,6 +175,8 @@ def create_input_data(self, arg: Argument, data: Any) -> str:  # noqa: C901
             or cpp_type == OPT_MEMORY_FORMAT
         ):
             ret_str += "std::nullopt;"
+        elif cpp_type == MEMORY_FORMAT:
+            ret_str += "at::MemoryFormat::Contiguous;"
         else:
             raise RuntimeError(f"Unsupported cpp type {cpp_type}")
         return ret_str + "\n"

@@ -267,6 +270,10 @@ def generate_suite_cpp(self) -> str:
   return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
 }}
 
+
+// torchgen assumes the "at" namespace is used for function default arguments.
+using at::MemoryFormat;
+
 {test_suites_cpp}
 """
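The using-declaration matters because, per the comment above, torchgen renders default arguments such as MemoryFormat::Contiguous without the at:: qualifier, so the generated C++ only compiles with at::MemoryFormat in scope. A self-contained sketch of the codegen branch added above (function name is mine; the constant matches the diff):

    MEMORY_FORMAT = "at::MemoryFormat"

    def create_input_data_for(cpp_type: str, arg_name: str) -> str:
        # Mirrors the new MEMORY_FORMAT branch: non-optional memory-format
        # arguments default to Contiguous in the generated tests.
        if cpp_type == MEMORY_FORMAT:
            return f"{cpp_type} {arg_name} = at::MemoryFormat::Contiguous;\n"
        raise RuntimeError(f"Unsupported cpp type {cpp_type}")

    print(create_input_data_for(MEMORY_FORMAT, "memory_format"))
    # -> at::MemoryFormat memory_format = at::MemoryFormat::Contiguous;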
