
Commit 7a46ac7

[ET-VK][12/n] aten.cat with more codegen
1. The `aten.cat` operation is straightforward to implement using the `copy_*_node` helpers.
2. The complexity comes from the codegen. We need to introduce an `AT_TENSOR_LIST` type, which contains a list of `AT_TENSOR` values with `is_in=True`. The tensor list itself, as a container, is not an `IOValueRef`, but the elements inside it are. This leads to some ugly if-then-else in the codegen.

Differential Revision: [D56626865](https://our.internmc.facebook.com/intern/diff/D56626865/)

[ghstack-poisoned]
1 parent d8191ee commit 7a46ac7
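To make point 2 concrete, here is a minimal sketch of the if-then-else the message describes, in plain Python with hypothetical names (`Arg`, the `emit_*` helpers, and the emitted C++ calls are illustrative stand-ins, not the actual generator code in backends/vulkan/test/op_tests):

# Illustrative sketch only. `Arg` and the emit_* helpers are hypothetical
# stand-ins for the real test-code generator; the emitted C++ is schematic.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Arg:
    name: str
    vk_type: str  # e.g. "AT_TENSOR" or "AT_TENSOR_LIST"
    elements: List[str] = field(default_factory=list)


def emit_input_decl(arg: Arg) -> str:
    if arg.vk_type == "AT_TENSOR_LIST":
        # The list itself is a plain container, not an IOValueRef, but
        # every element inside it is staged as an input (is_in=True).
        elem_decls = "\n".join(
            f"IOValueRef {e} = graph.add_input_tensor(/* ... */);"
            for e in arg.elements
        )
        refs = ", ".join(f"{e}.value" for e in arg.elements)
        return f"{elem_decls}\nValueRef {arg.name} = graph.add_value_list({{{refs}}});"
    if arg.vk_type == "AT_TENSOR":
        # A standalone input tensor is itself one IOValueRef.
        return f"IOValueRef {arg.name} = graph.add_input_tensor(/* ... */);"
    return f"ValueRef {arg.name} = graph.add_scalar(/* ... */);"


print(emit_input_decl(Arg("tensors", "AT_TENSOR_LIST", ["t0", "t1"])))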

File tree

7 files changed, +309 -24 lines changed
backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_cat_default_node(
    ComputeGraph& graph,
    ValueRef in_list_ref,
    ValueRef dim_ref,
    ValueRef out) {
  ValueListPtr input_list = graph.get_value_list(in_list_ref);

  // Every input tensor must use the channels-packed memory layout.
  for (ValueRef input_ref : *input_list) {
    vTensorPtr t_in = graph.get_tensor(input_ref);
    VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
  }

  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
  vTensorPtr t_out = graph.get_tensor(out);

  NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);

  // TODO: Find ways to factor out the similar code for width, height, and batch
  if (nchw_dim == DimWidth) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[0] += range.data[0];
    }
  } else if (nchw_dim == DimHeight) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[1] += range.data[1];
    }
  } else if (nchw_dim == DimBatch) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      api::utils::ivec3 range = t_in->texture_limits();
      add_copy_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset.data[2] += range.data[2];
    }
  } else if (nchw_dim == DimChannel) {
    // Channel concatenation cannot be expressed as a plain texel-offset
    // copy under channels packing, so it works in channel counts and uses
    // the dedicated channel-offset copy node.
    int32_t src_offset = 0;
    int32_t dst_offset = 0;

    for (ValueRef input_ref : *input_list) {
      vTensorPtr t_in = graph.get_tensor(input_ref);
      int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
      add_copy_channel_offset_node(
          graph, input_ref, range, src_offset, dst_offset, out);
      dst_offset += range;
    }
  } else {
    VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
  }
}

void cat_default(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  add_cat_default_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.cat.default, cat_default);
}

} // namespace vkcompute
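In the width, height, and batch branches above, the destination offset advances along the matching texture axis by each input's extent, while the channel branch accumulates channel counts from the tensor sizes. A quick sketch of the accumulation for a width-dim cat (plain Python; the widths are assumed values for illustration, not taken from the tests):

# Sketch: dst offsets produced by the DimWidth branch for three inputs
# whose x-extents (widths) are 3, 4, and 1.
widths = [3, 4, 1]
dst_offset = 0
for w in widths:
    print(f"copy x-range {w} into the output at x-offset {dst_offset}")
    dst_offset += w  # each input lands right after the previous one
# Prints offsets 0, 3, and 7; the concatenated output has width 8.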

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 15 additions & 2 deletions
@@ -96,10 +96,23 @@ void add_copy_channel_offset_node(
   VK_CHECK_COND(
       dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
-      "Source channel plus range should be less than or equal to input tensor's channel size");
+      "Src channel (",
+      src_channel_offset,
+      ") and range (",
+      channel_range,
+      ") should be less than or equal to input tensor's channel size (",
+      dim_at<Dim4D::Channel>(in_sizes),
+      ")");
+
   VK_CHECK_COND(
       dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
-      "Source channel and range should be less than or equal to input tensor's channel size");
+      "Dst channel (",
+      dst_channel_offset,
+      ") and range (",
+      channel_range,
+      ") should be less than or equal to output tensor's channel size (",
+      dim_at<Dim4D::Channel>(out_sizes),
+      ")");

   VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
   VK_CHECK_COND(

backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h

Lines changed: 46 additions & 0 deletions
@@ -70,4 +70,50 @@ uint32_t dim_at(const vTensor& v_in) {
   return dim_at<N>(v_in.sizes());
 }

+// A canonical way to represent dimensions as an enum. Intended to use the
+// same values as Dim4D for potential future refactoring.
+enum NchwDim {
+  DimWidth = 1,
+  DimHeight = 2,
+  DimChannel = 3,
+  DimBatch = 4,
+};
+
+/*
+ * This function returns an NchwDim given a tensor and a user-provided dim.
+ * The normalization is needed because user-facing dims index NCHW
+ * dimensions in "big-endian" order: dim=0 refers to the batch dimension of
+ * a 4-d tensor but to the height dimension of a 2-d tensor. Yet in the
+ * common channels-packed texture representation, a 2-d tensor has exactly
+ * the same layout as a 4-d tensor whose batch and channel sizes equal 1.
+ * Returning a canonical dimension simplifies dimension reasoning in the
+ * code.
+ */
+inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
+  return static_cast<NchwDim>(v_in.dim() - dim);
+}
+
+inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) {
+  switch (nchw_dim) {
+    case DimWidth:
+      os << "DimWidth";
+      break;
+    case DimHeight:
+      os << "DimHeight";
+      break;
+    case DimChannel:
+      os << "DimChannel";
+      break;
+    case DimBatch:
+      os << "DimBatch";
+      break;
+    default:
+      os << "DimUnknown";
+      break;
+  }
+  return os;
+}
+
 } // namespace vkcompute
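Since `normalize_to_nchw_dim` is just `v_in.dim() - dim`, the mapping can be spot-checked with a few worked values. A sketch mirroring the C++ enum above:

# Sketch of normalize_to_nchw_dim: the NchwDim value is ndim - dim.
NCHW_DIM = {1: "DimWidth", 2: "DimHeight", 3: "DimChannel", 4: "DimBatch"}


def normalize_to_nchw_dim(ndim: int, dim: int) -> str:
    return NCHW_DIM[ndim - dim]


assert normalize_to_nchw_dim(4, 0) == "DimBatch"    # dim=0 of a 4-d tensor
assert normalize_to_nchw_dim(2, 0) == "DimHeight"   # dim=0 of a 2-d tensor
assert normalize_to_nchw_dim(3, 0) == "DimChannel"  # dim=0 of a 3-d tensor
assert normalize_to_nchw_dim(4, 3) == "DimWidth"    # the last dim is width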

backends/vulkan/test/op_tests/cases.py

Lines changed: 47 additions & 0 deletions
@@ -428,6 +428,52 @@ def get_repeat_inputs():
     return test_suite


+def get_cat_inputs():
+    # TensorList must be specified as list of tuples
+    test_suite = VkTestSuite(
+        [
+            # Cat on Height
+            ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
+            ([(S1, 3, 5), (S1, 4, 5)], 1),
+            ([(3, 5), (4, 5)], 0),
+            ([(3, 5), (4, 5), (1, 5)], 0),
+            (
+                [
+                    (3, 5),
+                ],
+                0,
+            ),
+            # Cat on Width
+            ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
+            ([(S1, 5, 3), (S1, 5, 4)], 2),
+            ([(5, 3), (5, 4)], 1),
+            ([(5, 3), (5, 4), (5, 1)], 1),
+            (
+                [
+                    (5, 4),
+                ],
+                1,
+            ),
+            ([(5,), (6,)], 0),
+            # Cat on Batch
+            ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
+            ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
+            ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
+            # Cat on Channel
+            ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
+            ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
+            ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
+            ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    test_suite.dtypes = ["at::kFloat"]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -447,4 +493,5 @@ def get_repeat_inputs():
     "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
     "aten.clone.default": get_clone_inputs(),
     "aten.repeat.default": get_repeat_inputs(),
+    "aten.cat.default": get_cat_inputs(),
 }
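Each test case pairs a list of input shapes with the dim to concatenate on. For instance, the `([(3, 5), (4, 5)], 0)` height case corresponds to this eager-mode reference (a sketch, assuming PyTorch is available):

# Sketch: the eager-PyTorch behavior the ([(3, 5), (4, 5)], 0) case checks.
import torch

inputs = [torch.randn(3, 5), torch.randn(4, 5)]
out = torch.cat(inputs, dim=0)  # dim 0 of a 2-d tensor is height
assert out.shape == (7, 5)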

backends/vulkan/test/op_tests/generate_op_tests.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@
     TestSuite,
     TestSuiteGen,
 )
+from torchgen import local

 from torchgen.gen import parse_native_yaml, ParsedYaml
 from torchgen.model import DispatchKey, NativeFunction
@@ -45,6 +46,9 @@ def process_test_suites(
         cpp_generator.add_suite(registry_name, f, op_test_suite)


+@local.parametrize(
+    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+)
 def generate_cpp(
     native_functions_yaml_path: str, tags_path: str, output_dir: str
 ) -> None:
