Skip to content

Commit 23e04e2

Browse files
committed
[ET-VK][13/n] split_with_sizes with more test codegen
Pull Request resolved: #3389 Life is fun when the code-gen is more challenging than writing the operator itself. 1. Test codegen update to include vector of Tensor as output: `Tensor(a)[]`. 2. `aten.split_with_sizes.default` 3. `aten.split.Tensor` 4. Improve `DimUtils` for better dimension reasoning. ghstack-source-id: 224258792 @exported-using-ghexport Differential Revision: [D56660525](https://our.internmc.facebook.com/intern/diff/D56660525/)
1 parent ee02c32 commit 23e04e2

File tree

6 files changed

+351
-35
lines changed

6 files changed

+351
-35
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
/*
 * Adds a node that splits the input tensor `in` along `dim` into chunks whose
 * sizes are given by `split_sizes`, writing chunk i into the i-th tensor of
 * the value list referenced by `out_list_ref`.
 *
 * Only the channels-packed memory layout is supported for the input and all
 * outputs, and each output's extent along the split dimension must equal the
 * corresponding entry of `split_sizes`.
 */
void add_split_with_sizes_default_node(
    ComputeGraph& graph,
    ValueRef in,
    const std::vector<int64_t>& split_sizes,
    int64_t dim,
    ValueRef out_list_ref) {
  vTensorPtr t_in = graph.get_tensor(in);

  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));

  ValueListPtr out_list = graph.get_value_list(out_list_ref);

  // Normalize the user-facing dim to a canonical NCHW dimension so tensors of
  // different ranks can be handled uniformly below.
  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);

  VK_CHECK_COND(out_list->size() == split_sizes.size());

  // Validate every output before emitting any copy nodes.
  for (size_t split_idx = 0; split_idx < split_sizes.size(); split_idx++) {
    int64_t split_size = split_sizes[split_idx];
    ValueRef out_ref = (*out_list)[split_idx];

    vTensorPtr t_out = graph.get_tensor(out_ref);
    VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
    VK_CHECK_COND(dim_at(*t_out, nchw_dim) == split_size);
  }

  if (nchw_dim == DimWidth || nchw_dim == DimHeight || nchw_dim == DimBatch) {
    // Width, height and batch splits each walk along a single axis of the
    // texture extents: x for width, y for height, z for batch. The three
    // cases share one loop that advances the source offset along that axis.
    const int32_t axis =
        nchw_dim == DimWidth ? 0 : (nchw_dim == DimHeight ? 1 : 2);

    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

    for (ValueRef out_ref : *out_list) {
      // Doesn't need to use split_size since we have already verified that
      // the output tensor's size matches with the split_size.
      vTensorPtr t_out = graph.get_tensor(out_ref);
      api::utils::ivec3 range = t_out->texture_limits();
      add_copy_offset_node(graph, in, range, src_offset, dst_offset, out_ref);

      src_offset.data[axis] += range.data[axis];
    }
  } else if (nchw_dim == DimChannel) {
    // Channel splits cannot be expressed as a plain texel offset under
    // channels packing, so use the dedicated channel-offset copy node and
    // track the offset in channel units instead of texels.
    int32_t src_offset = 0;
    int32_t dst_offset = 0;

    for (ValueRef out_ref : *out_list) {
      vTensorPtr t_out = graph.get_tensor(out_ref);
      int32_t range = dim_at<Dim4D::Channel>(t_out->sizes());
      add_copy_channel_offset_node(
          graph, in, range, src_offset, dst_offset, out_ref);
      src_offset += range;
    }
  } else {
    VK_THROW("not implemented");
  }
}
96+
97+
/*
 * Wrapper overload: unpacks the split-sizes int list and the dim scalar from
 * their graph value references, then dispatches to the main implementation.
 */
void add_split_with_sizes_default_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef split_sizes_ref,
    ValueRef dim_ref,
    ValueRef out) {
  const std::vector<int64_t> sizes = *(graph.get_int_list(split_sizes_ref));
  const int64_t split_dim = graph.extract_scalar<int64_t>(dim_ref);

  add_split_with_sizes_default_node(graph, in, sizes, split_dim, out);
}
108+
109+
// Operator entry point for aten.split_with_sizes.default.
// args layout: {input, split_sizes, dim, output tensor list}.
void split_with_sizes_default(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  const ValueRef input = args[0];
  const ValueRef split_sizes = args[1];
  const ValueRef split_dim = args[2];
  const ValueRef out_list = args[3];
  add_split_with_sizes_default_node(graph, input, split_sizes, split_dim, out_list);
}
114+
115+
/*
 * Implements aten.split.Tensor: splits `in` along `dim` into chunks of
 * `split_size`. Per ATen semantics the trailing chunk is smaller when the
 * dimension size is not evenly divisible by split_size.
 */
void add_split_tensor_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef split_size_ref,
    ValueRef dim_ref,
    ValueRef out) {
  int64_t split_size = graph.extract_scalar<int64_t>(split_size_ref);
  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);

  // Guard against division by zero below.
  VK_CHECK_COND(split_size > 0);

  vTensorPtr t_in = graph.get_tensor(in);
  NchwDim nchw_dim = normalize_to_nchw_dim(*t_in, dim);
  int64_t size = dim_at(*t_in, nchw_dim);

  // Expand the single split_size into an explicit per-chunk size list. The
  // remainder chunk matches aten.split.Tensor semantics; previously it was
  // silently dropped when size % split_size != 0.
  std::vector<int64_t> split_sizes(size / split_size, split_size);
  if (const int64_t remainder = size % split_size; remainder != 0) {
    split_sizes.push_back(remainder);
  }

  add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
}
131+
132+
// Operator entry point for aten.split.Tensor.
// args layout: {input, split_size, dim, output tensor list}.
void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef input = args[0];
  const ValueRef split_size = args[1];
  const ValueRef split_dim = args[2];
  const ValueRef out_list = args[3];
  add_split_tensor_node(graph, input, split_size, split_dim, out_list);
}
135+
136+
// Bind the graph builders above to their ATen operator names so the Vulkan
// delegate can look them up at graph-build time.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
  VK_REGISTER_OP(aten.split.Tensor, split_tensor);
}
140+
141+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@
1212

1313
namespace vkcompute {
1414

15+
// A canonical way to represent dimensions as an enum. Motivation behind a
// canonical enum is that the user tensor uses a "big-endian"-ish mechanism to
// reference a dimension in an nchw-tensor, so tensors of different rank have
// different mappings from dim to the underlying texture dimension. For
// instance, for a 2d (height x width) tensor, dim 0 refers to height and
// dim 1 refers to width; for a 4d (batch x channel x height x width) tensor,
// dim 0 refers to batch and dim 1 refers to channel. Using this canonical
// enum brings clarity to the code.

enum NchwDim : uint32_t {
  DimWidth = 1u,
  DimHeight = 2u,
  DimChannel = 3u,
  DimBatch = 4u,
};

// Convert a dim provided by the user into the canonical enum by counting from
// the innermost dimension: dim() - dim == 1 is width, 2 is height, etc.
// NOTE(review): assumes 0 <= dim < v_in.dim(); negative dims do not appear to
// be normalized here — confirm callers pre-normalize.
inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
  return static_cast<NchwDim>(v_in.dim() - dim);
}
35+
1536
/*
1637
* Maps a semantic dimension name to an integer that
1738
* corresponds to its innermost ordering in a 4D tensor in
@@ -20,10 +41,10 @@ namespace vkcompute {
2041
* corresponds to 2, and so on.
2142
*/
2243
struct Dim4D {
23-
static constexpr uint32_t Width = 1u;
24-
static constexpr uint32_t Height = 2u;
25-
static constexpr uint32_t Channel = 3u;
26-
static constexpr uint32_t Batch = 4u;
44+
static constexpr uint32_t Width = DimWidth;
45+
static constexpr uint32_t Height = DimHeight;
46+
static constexpr uint32_t Channel = DimChannel;
47+
static constexpr uint32_t Batch = DimBatch;
2748
};
2849

2950
/*
@@ -65,34 +86,20 @@ uint32_t dim_at(const std::vector<int64_t>& sizes) {
6586
return dims < N ? 1 : api::utils::safe_downcast<uint32_t>(sizes[dims - N]);
6687
}
6788

89+
inline uint32_t dim_at(const std::vector<int64_t>& sizes, NchwDim nchw_dim) {
90+
const uint32_t dims = sizes.size();
91+
return dims < nchw_dim
92+
? 1
93+
: api::utils::safe_downcast<uint32_t>(sizes[dims - nchw_dim]);
94+
}
95+
6896
template <uint32_t N>
6997
uint32_t dim_at(const vTensor& v_in) {
7098
return dim_at<N>(v_in.sizes());
7199
}
72100

73-
// A canonical way to represent dimensions as enum. Intended to use the same
74-
// value as Dim4D for potential future refactoring.
75-
76-
enum NchwDim {
77-
DimWidth = 1,
78-
DimHeight = 2,
79-
DimChannel = 3,
80-
DimBatch = 4,
81-
};
82-
83-
/* This function return a NchwDim
84-
* given a Tensor and a user provided dim. The reason for this normalization is
85-
* that in the user tensor coordinate, it is using a "big-endian" mechanism when
86-
* referring to a nchw dimension, in that dim=0 refers to the batch dimension in
87-
* a 4d tensor but dim=0 reference to height in a 2d tensor. Despite in a common
88-
* texture representation of channel packing, a 2d tensor has exactly the same
89-
* layout as a 4d with the batch and channel size equals to 1. This function
90-
* returns a canonical dimension to simplify dimension reasoning in the code.
91-
*
92-
*/
93-
94-
inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
95-
return static_cast<NchwDim>(v_in.dim() - dim);
101+
inline uint32_t dim_at(const vTensor& v_in, NchwDim nchw_dim) {
102+
return dim_at(v_in.sizes(), nchw_dim);
96103
}
97104

98105
inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) {

backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1010

11-
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
12-
1311
namespace vkcompute {
1412

1513
//

backends/vulkan/test/op_tests/cases.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,7 @@ def get_cat_inputs():
438438
([(3, 5), (4, 5)], 0),
439439
([(3, 5), (4, 5), (1, 5)], 0),
440440
(
441-
[
442-
(3, 5),
443-
],
441+
[(3, 5)],
444442
0,
445443
),
446444
# Cat on Width
@@ -449,9 +447,7 @@ def get_cat_inputs():
449447
([(5, 3), (5, 4)], 1),
450448
([(5, 3), (5, 4), (5, 1)], 1),
451449
(
452-
[
453-
(5, 4),
454-
],
450+
[(5, 4)],
455451
1,
456452
),
457453
([(5,), (6,)], 0),
@@ -474,6 +470,91 @@ def get_cat_inputs():
474470
return test_suite
475471

476472

473+
def get_split_with_sizes_inputs():
    """Build the test suite for aten.split_with_sizes.default.

    Each case is (input shape, per-chunk sizes along `dim`, split dim).
    """
    # Renamed from "VkSliceTest" (copy-paste from the slice suite) so test
    # reprs identify the correct operator.
    Test = namedtuple("VkSplitTest", ["self", "sizes", "dim"])
    test_cases = [
        # Split on Width
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=3),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 9], dim=2),
        Test(self=(10, 10), sizes=[1, 9], dim=1),
        Test(self=(10,), sizes=[1, 9], dim=0),
        # Split on Height
        Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
        Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=1),
        Test(self=(7, 10, 10), sizes=[10], dim=1),
        Test(self=(7, 6, 10), sizes=[1, 1, 1, 1, 1, 1], dim=1),
        Test(self=(10, 10), sizes=[1, 2, 3, 4], dim=0),
        # Split on Batch
        Test(self=(10, 7, 10, 10), sizes=[3, 6, 1], dim=0),
        Test(self=(10, 7, 10, 10), sizes=[10], dim=0),
        # Split on Channel
        Test(self=(7, 13, 4, 8), sizes=[3, 6, 1, 3], dim=1),
        Test(self=(7, 13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=1),
        Test(self=(13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=0),
        Test(self=(13, 4, 8), sizes=[2, 9, 2], dim=0),
        Test(self=(13, 4, 8), sizes=[13], dim=0),
    ]
    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])

    test_suite.layouts = [
        "api::kChannelsPacked",
    ]
    test_suite.data_gen = "make_seq_tensor"
    test_suite.dtypes = ["at::kFloat"]
    return test_suite
506+
507+
508+
def get_split_tensor_inputs():
    """Build the test suite for aten.split.Tensor.

    Each case is (input shape, split_size, dim).
    """
    cases = [
        # Split on Width
        ((S1, 7, 10, 12), 12, 3),
        ((S1, 7, 10, 12), 3, 3),
        ((S1, 7, 10, 12), 1, 3),
        ((7, 10, 12), 12, 2),
        ((7, 10, 12), 3, 2),
        ((7, 10, 12), 1, 2),
        ((10, 12), 12, 1),
        ((10, 12), 3, 1),
        ((10, 12), 1, 1),
        ((12,), 12, 0),
        ((12,), 3, 0),
        ((12,), 1, 0),
        # Split on Height
        ((S1, 7, 12, 8), 12, 2),
        ((S1, 7, 12, 8), 3, 2),
        ((S1, 7, 12, 8), 1, 2),
        ((7, 12, 8), 12, 1),
        ((7, 12, 8), 3, 1),
        ((7, 12, 8), 1, 1),
        ((12, 8), 12, 0),
        ((12, 8), 3, 0),
        ((12, 8), 1, 0),
        # Split on Batch
        ((12, 7, 10, 10), 12, 0),
        ((12, 7, 10, 10), 3, 0),
        ((12, 7, 10, 10), 1, 0),
        # Split on Channel
        ((7, 15, 10, 10), 15, 1),
        ((7, 15, 10, 10), 5, 1),
        ((7, 15, 10, 10), 3, 1),
        ((7, 15, 10, 10), 1, 1),
        ((15, 10, 10), 15, 0),
        ((15, 10, 10), 5, 0),
        ((15, 10, 10), 3, 0),
        ((15, 10, 10), 1, 0),
    ]
    test_suite = VkTestSuite(cases)

    test_suite.layouts = [
        "api::kChannelsPacked",
    ]
    test_suite.data_gen = "make_seq_tensor"
    test_suite.dtypes = ["at::kFloat"]
    return test_suite
556+
557+
477558
test_suites = {
478559
"aten.add.Tensor": get_binary_elementwise_inputs(),
479560
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -494,4 +575,6 @@ def get_cat_inputs():
494575
"aten.clone.default": get_clone_inputs(),
495576
"aten.repeat.default": get_repeat_inputs(),
496577
"aten.cat.default": get_cat_inputs(),
578+
"aten.split_with_sizes.default": get_split_with_sizes_inputs(),
579+
"aten.split.Tensor": get_split_tensor_inputs(),
497580
}

0 commit comments

Comments
 (0)