Commit 3b8a671

[ET-VK][8/n] Unsqueeze
Exploit the fact that we can reduce the unsqueeze operation to a permute:

```
torch.all(torch.permute(x.unsqueeze(0), [1, 0, 2, 3]) == x.unsqueeze(1))
torch.all(torch.permute(x.unsqueeze(0), [1, 2, 0, 3]) == x.unsqueeze(2))
torch.all(torch.permute(x.unsqueeze(0), [1, 2, 3, 0]) == x.unsqueeze(3))
```

This diff introduces a minor change to the Permute implementation: it no longer requires the input tensor's dim count to match the length of the permute array. This allows the `unsqueeze` operation to be implemented as a no-op `unsqueeze(0)` followed by a permute.

Differential Revision: [D56347734](https://our.internmc.facebook.com/intern/diff/D56347734/)

ghstack-source-id: 223244602
Pull Request resolved: #3172
1 parent 0b007df commit 3b8a671
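The claimed identity is easy to sanity-check in PyTorch; a minimal sketch with an arbitrary 3D input (not part of the commit):

```python
import torch

# unsqueeze(0) adds a leading size-1 dim at no cost; a permute then
# moves that new dim into the requested position, matching unsqueeze(d).
x = torch.randn(2, 3, 4)
for d, perm in [(1, [1, 0, 2, 3]), (2, [1, 2, 0, 3]), (3, [1, 2, 3, 0])]:
    assert torch.equal(torch.permute(x.unsqueeze(0), perm), x.unsqueeze(d))
```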

File tree

- backends/vulkan/runtime/graph/ops/impl/Permute.cpp
- backends/vulkan/runtime/graph/ops/impl/Permute.h
- backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
- backends/vulkan/test/op_tests/cases.py

4 files changed: +141 −19 lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 32 additions & 19 deletions

```diff
@@ -6,8 +6,13 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -20,53 +25,51 @@ using api::utils::uvec4;
 
 void check_args(
     const vTensor& in,
-    const IntListPtr& permute_dims,
+    const std::vector<int64_t>& permute_dims,
     const vTensor& out) {
   VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
   VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
 
-  int64_t in_dim = in.dim();
+  // This implementation does not require the input tensor to have the same
+  // dim size as the permute argument. The code works as long as the input
+  // tensor's dim size is no larger than the permute dim array's; in that
+  // case, a size of 1 is assumed at the higher dimensions.
+
+  int64_t out_dim = out.dim();
   VK_CHECK_COND(
-      in_dim == permute_dims->size(),
-      "Input tensor dim size must match argument");
+      out_dim == permute_dims.size(),
+      "Output tensor dim size must match argument");
 }
 
 void add_permute_node(
     ComputeGraph& graph,
     ValueRef in,
-    ValueRef permute_dims_ref,
+    const std::vector<int64_t>& permute_dims,
     ValueRef out) {
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
 
-  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
-
   check_args(*t_in, permute_dims, *t_out);
 
-  uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
   uvec4 out_dims{0u, 1u, 2u, 3u};
 
-  int64_t in_dim = t_in->dim();
-
-  std::vector<bool> seen(in_dim);
-  for (int i = 0; i < in_dim; i++) {
-    int64_t permute_dim = (*permute_dims)[i];
+  int64_t out_dim = t_out->dim();
+  std::vector<bool> seen(out_dim);
+  for (int i = 0; i < t_out->dim(); i++) {
+    int64_t permute_dim = permute_dims[i];
     VK_CHECK_COND(
         !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
     seen[permute_dim] = true;
 
-    // Map to 4D tensor dims.
-    in_size.data[(4u - in_dim) + i] = t_in->size(i);
-    out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim);
-    out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim);
+    out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
   }
 
   std::string kernel_name = "permute";
   kernel_name.reserve(kShaderNameReserve);
   add_dtype_suffix(kernel_name, *t_out);
 
-  uint32_t out_channels = out_size.data[1u];
-  uint32_t in_channels = in_size.data[1u];
+  uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
+  uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
 
   uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
   uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
@@ -91,6 +94,16 @@ void add_permute_node(
       {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)}));
 }
 
+void add_permute_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef permute_dims_ref,
+    ValueRef out) {
+  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
+
+  add_permute_node(graph, in, *permute_dims, out);
+}
+
 void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_permute_node(graph, args[0], args[1], args[2]);
 }
```
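The relaxed check above means the permute array may be longer than the input tensor's rank, with the missing leading dims treated as size 1. A rough Python model of that convention (the helper name is invented for illustration; the actual Vulkan code maps dims inside the shader dispatch):

```python
import torch

def permute_with_implicit_leading_dims(x, permute_dims):
    # Pad with leading size-1 dims until the rank matches the permute
    # array, mirroring the "assume size 1 at higher dims" convention.
    while x.dim() < len(permute_dims):
        x = x.unsqueeze(0)
    return torch.permute(x, permute_dims)

y = permute_with_implicit_leading_dims(torch.randn(2, 3, 4), [1, 0, 2, 3])
print(y.shape)  # torch.Size([2, 1, 3, 4]), i.e. an unsqueeze at dim 1
```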
backends/vulkan/runtime/graph/ops/impl/Permute.h

Lines changed: 23 additions & 0 deletions

```diff
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+#include <vector>
+
+namespace vkcompute {
+
+void add_permute_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    const std::vector<int64_t>& permute_dims,
+    ValueRef out);
+
+} // namespace vkcompute
```
backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp

Lines changed: 58 additions & 0 deletions

```diff
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/Logging.h>
+#include <iostream>
+
+namespace vkcompute {
+
+void add_unsqueeze_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef dim_ref,
+    ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  VK_CHECK_COND(
+      t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
+
+  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+  int64_t out_dim = t_out->dim();
+
+  std::vector<int64_t> permute_dims(out_dim);
+  for (int i = 1; i <= dim; i++) {
+    permute_dims[i - 1] = i;
+  }
+  permute_dims[dim] = 0;
+
+  for (int i = dim + 1; i < out_dim; i++) {
+    permute_dims[i] = i;
+  }
+
+  add_permute_node(graph, in, permute_dims, out);
+}
+
+void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_unsqueeze_node(graph, args[0], args[1], args[2]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze);
+}
+
+} // namespace vkcompute
```
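The permute-array construction above can be mirrored in a few lines of Python to see which permutation each unsqueeze dim produces; for a 4D output it reproduces the permutations from the commit message (hypothetical helper, illustration only):

```python
def unsqueeze_permute_dims(dim, out_dim):
    perm = [0] * out_dim
    # Dims in front of the insertion point shift down by one...
    for i in range(1, dim + 1):
        perm[i - 1] = i
    # ...the leading dim added by the no-op unsqueeze(0) lands at `dim`...
    perm[dim] = 0
    # ...and trailing dims keep their positions.
    for i in range(dim + 1, out_dim):
        perm[i] = i
    return perm

print(unsqueeze_permute_dims(1, 4))  # [1, 0, 2, 3]
print(unsqueeze_permute_dims(2, 4))  # [1, 2, 0, 3]
print(unsqueeze_permute_dims(3, 4))  # [1, 2, 3, 0]
```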

backends/vulkan/test/op_tests/cases.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -192,6 +192,7 @@ def get_permute_inputs():
             ((9, 2), [1, 0]),
         ]
     )
+
     test_suite.layouts = ["api::kChannelsPacked"]
     return test_suite
 
@@ -290,6 +291,32 @@ def get_slice_inputs():
     return test_suite
 
 
+def get_unsqueeze_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((2, 3, 4), 0),
+            ((1, 1, 1), 0),
+            ((1, 1, 1), 1),
+            ((1, 1, 1), 2),
+            ((1, 1, 1), 3),
+            ((9, 9, 9), 0),
+            ((9, 9, 9), 1),
+            ((9, 9, 9), 2),
+            ((9, 9, 9), 3),
+            ((9, 9), 0),
+            ((9, 9), 1),
+            ((9, 9), 2),
+            ((9,), 0),
+            ((9,), 1),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -306,4 +333,5 @@ def get_slice_inputs():
     "aten.permute_copy.default": get_permute_inputs(),
     "aten.view_copy.default": get_view_inputs(),
     "aten.slice_copy.Tensor": get_slice_inputs(),
+    "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
 }
```
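Each test case pairs an input shape with the unsqueeze dim under test; the reference result is just PyTorch's own unsqueeze. For example, for the ((9, 9), 1) case:

```python
import torch

x = torch.arange(81).reshape(9, 9)
print(x.unsqueeze(1).shape)  # torch.Size([9, 1, 9])
```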
