Skip to content

Commit 2166892

Browse files
committed
[ET-VK][8/n] Unsqueeze
Pull Request resolved: #3172 We exploit the fact that the unsqueeze operation can be reduced to a permute. ``` torch.all(torch.permute(x.unsqueeze(0), [1, 0, 2, 3]) == x.unsqueeze(1)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 0, 3]) == x.unsqueeze(2)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 3, 0]) == x.unsqueeze(3)) ``` This diff introduces a minor change to the Permute implementation so that it no longer requires the input dimension length to match the length of the permute array. This allows the `unsqueeze` operation to be implemented as a no-op `unsqueeze(0)` followed by a permute. Differential Revision: [D56347734](https://our.internmc.facebook.com/intern/diff/D56347734/) ghstack-source-id: 223528485
1 parent cb77763 commit 2166892

File tree

4 files changed

+136
-19
lines changed

4 files changed

+136
-19
lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
10+
911
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1012

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1114
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1215
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1316
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -20,53 +23,51 @@ using api::utils::uvec4;
2023

2124
void check_args(
2225
const vTensor& in,
23-
const IntListPtr& permute_dims,
26+
const std::vector<int64_t>& permute_dims,
2427
const vTensor& out) {
2528
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
2629
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
2730

28-
int64_t in_dim = in.dim();
31+
// This implementation does not require the input tensor to have the same
32+
// dim size as the argument. The code will work as long as the input tensor's
33+
// dim size is no longer than the permute dim array. In this case, the code
34+
// assumes a size of 1 at the higher dimensions.
35+
36+
int64_t out_dim = out.dim();
2937
VK_CHECK_COND(
30-
in_dim == permute_dims->size(),
31-
"Input tensor dim size must match argument");
38+
out_dim == permute_dims.size(),
39+
"Output tensor dim size must match argument");
3240
}
3341

3442
void add_permute_node(
3543
ComputeGraph& graph,
3644
ValueRef in,
37-
ValueRef permute_dims_ref,
45+
const std::vector<int64_t>& permute_dims,
3846
ValueRef out) {
3947
vTensorPtr t_in = graph.get_tensor(in);
4048
vTensorPtr t_out = graph.get_tensor(out);
4149

42-
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
43-
4450
check_args(*t_in, permute_dims, *t_out);
4551

46-
uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
4752
uvec4 out_dims{0u, 1u, 2u, 3u};
4853

49-
int64_t in_dim = t_in->dim();
50-
51-
std::vector<bool> seen(in_dim);
52-
for (int i = 0; i < in_dim; i++) {
53-
int64_t permute_dim = (*permute_dims)[i];
54+
int64_t out_dim = t_out->dim();
55+
std::vector<bool> seen(out_dim);
56+
for (int i = 0; i < t_out->dim(); i++) {
57+
int64_t permute_dim = permute_dims[i];
5458
VK_CHECK_COND(
5559
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
5660
seen[permute_dim] = true;
5761

58-
// Map to 4D tensor dims.
59-
in_size.data[(4u - in_dim) + i] = t_in->size(i);
60-
out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim);
61-
out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim);
62+
out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
6263
}
6364

6465
std::string kernel_name = "permute";
6566
kernel_name.reserve(kShaderNameReserve);
6667
add_dtype_suffix(kernel_name, *t_out);
6768

68-
uint32_t out_channels = out_size.data[1u];
69-
uint32_t in_channels = in_size.data[1u];
69+
uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
70+
uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
7071

7172
uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
7273
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
@@ -98,6 +99,16 @@ void add_permute_node(
9899
{}));
99100
}
100101

102+
void add_permute_node(
103+
ComputeGraph& graph,
104+
ValueRef in,
105+
ValueRef permute_dims_ref,
106+
ValueRef out) {
107+
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
108+
109+
add_permute_node(graph, in, *permute_dims, out);
110+
}
111+
101112
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
102113
return add_permute_node(graph, args[0], args[1], args[2]);
103114
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/api/api.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
#include <vector>
14+
15+
namespace vkcompute {
16+
17+
void add_permute_node(
18+
ComputeGraph& graph,
19+
ValueRef in,
20+
const std::vector<int64_t>& permute_dims,
21+
ValueRef out);
22+
23+
} // namespace vkcompute
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
namespace vkcompute {
19+
20+
void add_unsqueeze_node(
21+
ComputeGraph& graph,
22+
ValueRef in,
23+
ValueRef dim_ref,
24+
ValueRef out) {
25+
vTensorPtr t_in = graph.get_tensor(in);
26+
vTensorPtr t_out = graph.get_tensor(out);
27+
28+
VK_CHECK_COND(
29+
t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
30+
31+
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
32+
int64_t out_dim = t_out->dim();
33+
34+
std::vector<int64_t> permute_dims(out_dim);
35+
for (int i = 1; i <= dim; i++) {
36+
permute_dims[i - 1] = i;
37+
}
38+
permute_dims[dim] = 0;
39+
40+
for (int i = dim + 1; i < out_dim; i++) {
41+
permute_dims[i] = i;
42+
}
43+
44+
add_permute_node(graph, in, permute_dims, out);
45+
}
46+
47+
void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
48+
return add_unsqueeze_node(graph, args[0], args[1], args[2]);
49+
}
50+
51+
REGISTER_OPERATORS {
52+
VK_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze);
53+
}
54+
55+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def get_permute_inputs():
214214
((9, 2), [1, 0]),
215215
]
216216
)
217+
217218
test_suite.layouts = ["api::kChannelsPacked"]
218219
return test_suite
219220

@@ -312,6 +313,32 @@ def get_slice_inputs():
312313
return test_suite
313314

314315

316+
def get_unsqueeze_inputs():
317+
test_suite = VkTestSuite(
318+
[
319+
((2, 3, 4), 0),
320+
((1, 1, 1), 0),
321+
((1, 1, 1), 1),
322+
((1, 1, 1), 2),
323+
((1, 1, 1), 3),
324+
((9, 9, 9), 0),
325+
((9, 9, 9), 1),
326+
((9, 9, 9), 2),
327+
((9, 9, 9), 3),
328+
((9, 9), 0),
329+
((9, 9), 1),
330+
((9, 9), 2),
331+
((9,), 0),
332+
((9,), 1),
333+
]
334+
)
335+
test_suite.layouts = [
336+
"api::kChannelsPacked",
337+
]
338+
test_suite.data_gen = "make_seq_tensor"
339+
return test_suite
340+
341+
315342
test_suites = {
316343
"aten.add.Tensor": get_binary_elementwise_inputs(),
317344
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -328,4 +355,5 @@ def get_slice_inputs():
328355
"aten.permute_copy.default": get_permute_inputs(),
329356
"aten.view_copy.default": get_view_inputs(),
330357
"aten.slice_copy.Tensor": get_slice_inputs(),
358+
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
331359
}

0 commit comments

Comments
 (0)