Skip to content

Commit 547ab28

Browse files
committed
[ET-VK][8/n] Unsqueeze
Pull Request resolved: #3172 We reduce the unsqueeze operation to permute, exploiting the fact that: ``` torch.all(torch.permute(x.unsqueeze(0), [1, 0, 2, 3]) == x.unsqueeze(1)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 0, 3]) == x.unsqueeze(2)) torch.all(torch.permute(x.unsqueeze(0), [1, 2, 3, 0]) == x.unsqueeze(3)) ``` This diff introduces a minor change to the Permute implementation so that it no longer requires the input tensor's dimension count to match the length of the permute array. This allows the `unsqueeze` operation to be implemented as a no-op `unsqueeze(0)` followed by a permute. ghstack-source-id: 223698863 Differential Revision: [D56347734](https://our.internmc.facebook.com/intern/diff/D56347734/)
1 parent de0c233 commit 547ab28

File tree

4 files changed

+135
-19
lines changed

4 files changed

+135
-19
lines changed

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
10+
911
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1012

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1114
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1215
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1316
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
@@ -20,53 +23,51 @@ using api::utils::uvec4;
2023

2124
void check_args(
2225
const vTensor& in,
23-
const IntListPtr& permute_dims,
26+
const std::vector<int64_t>& permute_dims,
2427
const vTensor& out) {
2528
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
2629
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
2730

28-
int64_t in_dim = in.dim();
31+
// This implementation does not require the input tensor to have the same
32+
// dim size as the argument. The code will work as long as the input tensor's
33+
// dim size is shorter than the permute dim array. In this case, the code
34+
// assumes a size of 1 at the higher dimensions.
35+
36+
int64_t out_dim = out.dim();
2937
VK_CHECK_COND(
30-
in_dim == permute_dims->size(),
31-
"Input tensor dim size must match argument");
38+
out_dim == permute_dims.size(),
39+
"Output tensor dim size must match argument");
3240
}
3341

3442
void add_permute_node(
3543
ComputeGraph& graph,
3644
ValueRef in,
37-
ValueRef permute_dims_ref,
45+
const std::vector<int64_t>& permute_dims,
3846
ValueRef out) {
3947
vTensorPtr t_in = graph.get_tensor(in);
4048
vTensorPtr t_out = graph.get_tensor(out);
4149

42-
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
43-
4450
check_args(*t_in, permute_dims, *t_out);
4551

46-
uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
4752
uvec4 out_dims{0u, 1u, 2u, 3u};
4853

49-
int64_t in_dim = t_in->dim();
50-
51-
std::vector<bool> seen(in_dim);
52-
for (int i = 0; i < in_dim; i++) {
53-
int64_t permute_dim = (*permute_dims)[i];
54+
int64_t out_dim = t_out->dim();
55+
std::vector<bool> seen(out_dim);
56+
for (int i = 0; i < t_out->dim(); i++) {
57+
int64_t permute_dim = permute_dims[i];
5458
VK_CHECK_COND(
5559
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
5660
seen[permute_dim] = true;
5761

58-
// Map to 4D tensor dims.
59-
in_size.data[(4u - in_dim) + i] = t_in->size(i);
60-
out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim);
61-
out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim);
62+
out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
6263
}
6364

6465
std::string kernel_name = "permute";
6566
kernel_name.reserve(kShaderNameReserve);
6667
add_dtype_suffix(kernel_name, *t_out);
6768

68-
uint32_t out_channels = out_size.data[1u];
69-
uint32_t in_channels = in_size.data[1u];
69+
uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
70+
uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());
7071

7172
uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
7273
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
@@ -98,6 +99,16 @@ void add_permute_node(
9899
{}));
99100
}
100101

102+
void add_permute_node(
103+
ComputeGraph& graph,
104+
ValueRef in,
105+
ValueRef permute_dims_ref,
106+
ValueRef out) {
107+
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
108+
109+
add_permute_node(graph, in, *permute_dims, out);
110+
}
111+
101112
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
102113
return add_permute_node(graph, args[0], args[1], args[2]);
103114
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
14+
15+
#include <vector>
16+
17+
namespace vkcompute {
18+
19+
void add_permute_node(
20+
ComputeGraph& graph,
21+
ValueRef in,
22+
const std::vector<int64_t>& permute_dims,
23+
ValueRef out);
24+
25+
} // namespace vkcompute
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
14+
15+
namespace vkcompute {
16+
17+
void add_unsqueeze_node(
18+
ComputeGraph& graph,
19+
ValueRef in,
20+
ValueRef dim_ref,
21+
ValueRef out) {
22+
vTensorPtr t_in = graph.get_tensor(in);
23+
vTensorPtr t_out = graph.get_tensor(out);
24+
25+
VK_CHECK_COND(
26+
t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");
27+
28+
int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
29+
int64_t out_dim = t_out->dim();
30+
31+
std::vector<int64_t> permute_dims(out_dim);
32+
for (int i = 1; i <= dim; i++) {
33+
permute_dims[i - 1] = i;
34+
}
35+
permute_dims[dim] = 0;
36+
37+
for (int i = dim + 1; i < out_dim; i++) {
38+
permute_dims[i] = i;
39+
}
40+
41+
add_permute_node(graph, in, permute_dims, out);
42+
}
43+
44+
void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
45+
return add_unsqueeze_node(graph, args[0], args[1], args[2]);
46+
}
47+
48+
REGISTER_OPERATORS {
49+
VK_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze);
50+
}
51+
52+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def get_permute_inputs():
236236
((9, 2), [1, 0]),
237237
]
238238
)
239+
239240
test_suite.layouts = ["api::kChannelsPacked"]
240241
return test_suite
241242

@@ -334,6 +335,32 @@ def get_slice_inputs():
334335
return test_suite
335336

336337

338+
def get_unsqueeze_inputs():
339+
test_suite = VkTestSuite(
340+
[
341+
((2, 3, 4), 0),
342+
((1, 1, 1), 0),
343+
((1, 1, 1), 1),
344+
((1, 1, 1), 2),
345+
((1, 1, 1), 3),
346+
((9, 9, 9), 0),
347+
((9, 9, 9), 1),
348+
((9, 9, 9), 2),
349+
((9, 9, 9), 3),
350+
((9, 9), 0),
351+
((9, 9), 1),
352+
((9, 9), 2),
353+
((9,), 0),
354+
((9,), 1),
355+
]
356+
)
357+
test_suite.layouts = [
358+
"api::kChannelsPacked",
359+
]
360+
test_suite.data_gen = "make_seq_tensor"
361+
return test_suite
362+
363+
337364
test_suites = {
338365
"aten.add.Tensor": get_binary_elementwise_inputs(),
339366
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -350,4 +377,5 @@ def get_slice_inputs():
350377
"aten.permute_copy.default": get_permute_inputs(),
351378
"aten.view_copy.default": get_view_inputs(),
352379
"aten.slice_copy.Tensor": get_slice_inputs(),
380+
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
353381
}

0 commit comments

Comments
 (0)