Skip to content

Commit 4bbbde3

Browse files
committed
[ET-VK][13/n] split_with_sizes with more test codegen
Life is fun when the code-gen is more challenging than writing the operator itself. 1. Test codegen update to include vector of Tensor as output: `Tensor(a)[]`. 2. `aten.split_with_sizes.default` 3. `aten.split.Tensor` 4. Improve `DimUtils` for better dimension reasoning. Differential Revision: [D56660525](https://our.internmc.facebook.com/intern/diff/D56660525/) [ghstack-poisoned]
1 parent c90d054 commit 4bbbde3

File tree

6 files changed

+353
-35
lines changed

6 files changed

+353
-35
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
17+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
18+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
19+
20+
namespace vkcompute {
21+
22+
/*
 * Splits `in` along `dim` into the tensors listed in `out_list_ref`, where
 * output i has size split_sizes[i] along the split dimension. Emits one copy
 * node per output. Only channels-packed texture layouts are supported.
 */
void add_split_with_sizes_default_node(
    ComputeGraph& graph,
    ValueRef in,
    const std::vector<int64_t>& split_sizes,
    int64_t dim,
    ValueRef out_list_ref) {
  vTensorPtr t_in = graph.get_tensor(in);

  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));

  ValueListPtr out_list = graph.get_value_list(out_list_ref);

  // Normalize the user-facing dim into a canonical NCHW dimension so tensors
  // of different rank can be reasoned about uniformly below.
  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);

  VK_CHECK_COND(out_list->size() == split_sizes.size());

  // Validate every output up front: matching memory layout, and a size along
  // the split dimension equal to the requested split size.
  for (size_t split_idx = 0; split_idx < split_sizes.size(); split_idx++) {
    int64_t split_size = split_sizes[split_idx];
    ValueRef out_ref = (*out_list)[split_idx];

    vTensorPtr t_out = graph.get_tensor(out_ref);
    VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
    VK_CHECK_COND(dim_at(*t_out, nchw_dim) == split_size);
  }

  if (nchw_dim == DimWidth) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef out_ref : *out_list) {
      // Doesn't need to use split_size since we have already verified that the
      // output tensor's size matches with the split_size.
      vTensorPtr t_out = graph.get_tensor(out_ref);
      api::utils::ivec3 range = t_out->texture_limits();
      add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);

      // Advance the source cursor along the texture x-axis (width).
      src_offset.data[0] += range.data[0];
    }
  } else if (nchw_dim == DimHeight) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef out_ref : *out_list) {
      vTensorPtr t_out = graph.get_tensor(out_ref);
      api::utils::ivec3 range = t_out->texture_limits();
      add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);

      // Advance the source cursor along the texture y-axis (height).
      src_offset.data[1] += range.data[1];
    }
  } else if (nchw_dim == DimBatch) {
    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef out_ref : *out_list) {
      vTensorPtr t_out = graph.get_tensor(out_ref);
      api::utils::ivec3 range = t_out->texture_limits();
      add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);

      // Advance the source cursor along the texture z-axis (batch).
      src_offset.data[2] += range.data[2];
    }
  } else if (nchw_dim == DimChannel) {
    // Channels are folded into texels under channels-packing, so a dedicated
    // channel-offset copy is used instead of a plain texel-offset copy.
    int32_t src_offset = 0;
    int32_t dst_offset = 0;

    for (ValueRef out_ref : *out_list) {
      vTensorPtr t_out = graph.get_tensor(out_ref);
      int32_t range = dim_at<Dim4D::Channel>(t_out->sizes());
      add_copy_channel_offset_node(
          graph, in, range, src_offset, dst_offset, out_ref);
      src_offset += range;
    }

  } else {
    VK_THROW("not implemented");
  }
}
98+
99+
/*
 * ValueRef-based overload: materializes the scalar `dim` and the int list of
 * split sizes from the graph, then delegates to the typed overload.
 */
void add_split_with_sizes_default_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef split_sizes_ref,
    ValueRef dim_ref,
    ValueRef out) {
  const std::vector<int64_t> split_sizes =
      *(graph.get_int_list(split_sizes_ref));
  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

  add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
}
110+
111+
// Operator entry point for aten.split_with_sizes.default.
void split_with_sizes_default(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  // Unpack the positional operator arguments for readability.
  const ValueRef in = args[0];
  const ValueRef split_sizes = args[1];
  const ValueRef dim = args[2];
  const ValueRef out = args[3];
  add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
}
116+
117+
/*
 * Implements aten.split.Tensor by expanding the scalar `split_size` into an
 * explicit list of split sizes and delegating to the split_with_sizes node.
 * Per aten.split.Tensor semantics, when the dimension size is not evenly
 * divisible by split_size, the final chunk is smaller than split_size.
 */
void add_split_tensor_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef split_size_ref,
    ValueRef dim_ref,
    ValueRef out) {
  int64_t split_size = graph.extract_scalar<int64_t>(split_size_ref);
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

  // Guard against a division by zero / nonsensical chunking below.
  VK_CHECK_COND(split_size > 0);

  vTensorPtr t_in = graph.get_tensor(in);
  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
  int64_t size = dim_at(*t_in, nchw_dim);

  std::vector<int64_t> split_sizes(size / split_size, split_size);
  // Append the smaller trailing chunk when size is not a multiple of
  // split_size (e.g. size=10, split_size=3 -> [3, 3, 3, 1]).
  const int64_t remainder = size % split_size;
  if (remainder != 0) {
    split_sizes.push_back(remainder);
  }

  add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
}
133+
134+
// Operator entry point for aten.split.Tensor.
void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // Unpack the positional operator arguments for readability.
  const ValueRef in = args[0];
  const ValueRef split_size = args[1];
  const ValueRef dim = args[2];
  const ValueRef out = args[3];
  add_split_tensor_node(graph, in, split_size, dim, out);
}
137+
138+
REGISTER_OPERATORS {
  // torch.split(tensor, split_size: int, dim) variant.
  VK_REGISTER_OP(aten.split.Tensor, split_tensor);
  // torch.split_with_sizes / torch.split(tensor, sizes: list, dim) variant.
  VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
}
142+
143+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@
1212

1313
namespace vkcompute {
1414

15+
// A canonical way to represent dimensions as enum. Motivation behind a
16+
// canonical enum is that in the user tensor, it is using a "big-endian"-ish
17+
// mechanism to reference a dimension in a nchw-tensor, leading to tensors of
18+
// different dimensions having different mappings from dim to the underlying texture
19+
// dimension. For instance, for a 2d (height x width) tensor, dim 0 refers to
20+
// height and dim 1 refers to width; for a 4d (batch x channel x height x width)
21+
// tensor, dim 0 refers to batch and dim 1 refers to channel. Using this
22+
// canonical enum allows us to bring clarity in code.
23+
24+
enum NchwDim : uint32_t {
25+
DimWidth = 1u,
26+
DimHeight = 2u,
27+
DimChannel = 3u,
28+
DimBatch = 4u,
29+
};
30+
31+
// Convert a dim provided by user into canonical enum.
32+
inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
33+
return static_cast<NchwDim>(v_in.dim() - dim);
34+
}
35+
1536
/*
1637
* Maps a semantic dimension name to an integer that
1738
* corresponds to its innermost ordering in a 4D tensor in
@@ -20,10 +41,10 @@ namespace vkcompute {
2041
* corresponds to 2, and so on.
2142
*/
2243
struct Dim4D {
23-
static constexpr uint32_t Width = 1u;
24-
static constexpr uint32_t Height = 2u;
25-
static constexpr uint32_t Channel = 3u;
26-
static constexpr uint32_t Batch = 4u;
44+
static constexpr uint32_t Width = DimWidth;
45+
static constexpr uint32_t Height = DimHeight;
46+
static constexpr uint32_t Channel = DimChannel;
47+
static constexpr uint32_t Batch = DimBatch;
2748
};
2849

2950
/*
@@ -65,34 +86,20 @@ uint32_t dim_at(const std::vector<int64_t>& sizes) {
6586
return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
6687
}
6788

89+
inline uint32_t dim_at(const std::vector<int64_t>& sizes, NchwDim nchw_dim) {
90+
const uint32_t dims = sizes.size();
91+
return dims < nchw_dim
92+
? 1
93+
: api::utils::safe_downcast<uint32_t>(sizes[dims - nchw_dim]);
94+
}
95+
6896
template <uint32_t N>
6997
uint32_t dim_at(const vTensor& v_in) {
7098
return dim_at<N>(v_in.sizes());
7199
}
72100

73-
// A canonical way to represent dimensions as enum. Intended to use the same
74-
// value as Dim4D for potential future refactoring.
75-
76-
enum NchwDim {
77-
DimWidth = 1,
78-
DimHeight = 2,
79-
DimChannel = 3,
80-
DimBatch = 4,
81-
};
82-
83-
/* This function return a NchwDim
84-
* given a Tensor and a user provided dim. The reason for this normalization is
85-
* that in the user tensor coordinate, it is using a "big-endian" mechanism when
86-
* referring to a nchw dimension, in that dim=0 refers to the batch dimension in
87-
* a 4d tensor but dim=0 reference to height in a 2d tensor. Despite in a common
88-
* texture representation of channel packing, a 2d tensor has exactly the same
89-
* layout as a 4d with the batch and channel size equals to 1. This function
90-
* returns a canonical dimension to simplify dimension reasoning in the code.
91-
*
92-
*/
93-
94-
inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
95-
return static_cast<NchwDim>(v_in.dim() - dim);
101+
inline uint32_t dim_at(const vTensor& v_in, NchwDim nchw_dim) {
102+
return dim_at(v_in.sizes(), nchw_dim);
96103
}
97104

98105
inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) {

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1010

11-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
12-
1311
namespace vkcompute {
1412

1513
//

backends/vulkan/test/op_tests/cases.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,7 @@ def get_cat_inputs():
438438
([(3, 5), (4, 5)], 0),
439439
([(3, 5), (4, 5), (1, 5)], 0),
440440
(
441-
[
442-
(3, 5),
443-
],
441+
[(3, 5)],
444442
0,
445443
),
446444
# Cat on Width
@@ -449,9 +447,7 @@ def get_cat_inputs():
449447
([(5, 3), (5, 4)], 1),
450448
([(5, 3), (5, 4), (5, 1)], 1),
451449
(
452-
[
453-
(5, 4),
454-
],
450+
[(5, 4)],
455451
1,
456452
),
457453
([(5,), (6,)], 0),
@@ -474,6 +470,91 @@ def get_cat_inputs():
474470
return test_suite
475471

476472

473+
def get_split_with_sizes_inputs():
    # Each case is (input shape, per-chunk sizes, split dim); sizes must sum
    # to the extent of the split dimension.
    Test = namedtuple("VkSliceTest", ["self", "sizes", "dim"])
    cases = [
        # Split on Width
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=3),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 9], dim=2),
        Test(self=(10, 10), sizes=[1, 9], dim=1),
        Test(self=(10,), sizes=[1, 9], dim=0),
        # Split on Height
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=1),
        Test(self=(7, 10, 10), sizes=[10], dim=1),
        Test(self=(7, 6, 10), sizes=[1, 1, 1, 1, 1, 1], dim=1),
        Test(self=(10, 10), sizes=[1, 2, 3, 4], dim=0),
        # Split on Batch
        Test(self=(10, 7, 10, 10), sizes=[3, 6, 1], dim=0),
        Test(self=(10, 7, 10, 10), sizes=[10], dim=0),
        # Split on Channel
        Test(self=(7, 13, 4, 8), sizes=[3, 6, 1, 3], dim=1),
        Test(self=(7, 13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=1),
        Test(self=(13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=0),
        Test(self=(13, 4, 8), sizes=[2, 9, 2], dim=0),
        Test(self=(13, 4, 8), sizes=[13], dim=0),
    ]

    suite = VkTestSuite([tuple(case) for case in cases])
    suite.layouts = ["api::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.dtypes = ["at::kFloat"]
    return suite
506+
507+
508+
def get_split_tensor_inputs():
    # Each case is (input shape, split_size, dim), grouped by the NCHW
    # dimension being split.
    width_cases = [
        ((S1, 7, 10, 12), 12, 3),
        ((S1, 7, 10, 12), 3, 3),
        ((S1, 7, 10, 12), 1, 3),
        ((7, 10, 12), 12, 2),
        ((7, 10, 12), 3, 2),
        ((7, 10, 12), 1, 2),
        ((10, 12), 12, 1),
        ((10, 12), 3, 1),
        ((10, 12), 1, 1),
        ((12,), 12, 0),
        ((12,), 3, 0),
        ((12,), 1, 0),
    ]
    height_cases = [
        ((S1, 7, 12, 8), 12, 2),
        ((S1, 7, 12, 8), 3, 2),
        ((S1, 7, 12, 8), 1, 2),
        ((7, 12, 8), 12, 1),
        ((7, 12, 8), 3, 1),
        ((7, 12, 8), 1, 1),
        ((12, 8), 12, 0),
        ((12, 8), 3, 0),
        ((12, 8), 1, 0),
    ]
    batch_cases = [
        ((12, 7, 10, 10), 12, 0),
        ((12, 7, 10, 10), 3, 0),
        ((12, 7, 10, 10), 1, 0),
    ]
    channel_cases = [
        ((7, 15, 10, 10), 15, 1),
        ((7, 15, 10, 10), 5, 1),
        ((7, 15, 10, 10), 3, 1),
        ((7, 15, 10, 10), 1, 1),
        ((15, 10, 10), 15, 0),
        ((15, 10, 10), 5, 0),
        ((15, 10, 10), 3, 0),
        ((15, 10, 10), 1, 0),
    ]

    suite = VkTestSuite(
        width_cases + height_cases + batch_cases + channel_cases
    )
    suite.layouts = ["api::kChannelsPacked"]
    suite.data_gen = "make_seq_tensor"
    suite.dtypes = ["at::kFloat"]
    return suite
return test_suite
556+
557+
477558
test_suites = {
478559
"aten.add.Tensor": get_binary_elementwise_inputs(),
479560
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -494,4 +575,6 @@ def get_cat_inputs():
494575
"aten.clone.default": get_clone_inputs(),
495576
"aten.repeat.default": get_repeat_inputs(),
496577
"aten.cat.default": get_cat_inputs(),
578+
"aten.split_with_sizes.default": get_split_with_sizes_inputs(),
579+
"aten.split.Tensor": get_split_tensor_inputs(),
497580
}

0 commit comments

Comments
 (0)