Skip to content

Commit 7047162

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
Implement aten.squeeze_copy.dims (#4223)
Summary: Pull Request resolved: #4223 Implement aten.squeeze_copy.dims operator This op is compiled from `torch.squeeze(x, dims=0)`. bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: jorgep31415 Differential Revision: D59605342 fbshipit-source-id: 2acabe080360875937e4e48d427d6cc7fae802ff
1 parent 9221ab6 commit 7047162

File tree

6 files changed

+125
-9
lines changed

6 files changed

+125
-9
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __contains__(self, op):
9797
]
9898

9999
SHAPE_MANIPULATION_OPS = [
100+
exir_ops.edge.aten.squeeze_copy.dims,
100101
exir_ops.edge.aten.unsqueeze_copy.default,
101102
exir_ops.edge.aten.view_copy.default,
102103
exir_ops.edge.aten.permute_copy.default,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
14+
15+
namespace vkcompute {
16+
17+
void add_clone_node(ComputeGraph& graph, const ValueRef in, const ValueRef out);
18+
19+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ void check_args(
3535
// dim size as the argument. The code will work as long as the input tensor's
3636
// dim size is shorter than the permute dim array. In this case, the code
3737
// assume size of 1 at the higher dimensions.
38-
39-
int64_t out_dim = out.dim();
40-
VK_CHECK_COND(
41-
out_dim == permute_dims.size(),
42-
"Output tensor dim size must match argument");
4338
}
4439

4540
} // namespace
@@ -56,15 +51,18 @@ void add_permute_node(
5651

5752
ivec4 out_dims{0, 1, 2, 3};
5853

59-
int64_t out_dim = t_out->dim();
60-
std::vector<bool> seen(out_dim);
61-
for (int i = 0; i < t_out->dim(); i++) {
54+
// Special cases of squeeze/unsqueeze. Because the input dim size can be
55+
// different with output dim size. So pick t_in->dim() if squeeze, and
56+
// t_out->dim() if unsqueeze to create parameter for permute.
57+
int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
58+
std::vector<bool> seen(out_ndim);
59+
for (int i = 0; i < out_ndim; i++) {
6260
int64_t permute_dim = permute_dims[i];
6361
VK_CHECK_COND(
6462
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
6563
seen[permute_dim] = true;
6664

67-
out_dims.data[(4u - out_dim) + i] = permute_dim + (4 - out_dim);
65+
out_dims.data[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
6866
}
6967

7068
std::string kernel_name = "permute";
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Clone.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
15+
16+
namespace vkcompute {
17+
18+
void add_squeeze_copy_dims_node(
19+
ComputeGraph& graph,
20+
ValueRef in,
21+
ValueRef dims_ref,
22+
ValueRef out) {
23+
vTensorPtr t_in = graph.get_tensor(in);
24+
vTensorPtr t_out = graph.get_tensor(out);
25+
26+
IntListPtr dims = graph.get_int_list(dims_ref);
27+
std::vector<int64_t> squeeze_dims;
28+
// Filter out edge cases that we don't need squeeze:
29+
// 1. The size of squeeze dim is larger than 1.
30+
// 2. Squeeze outter most dim
31+
// For these cases, just pass input to output via clone.
32+
for (int i = 0; i < dims->size(); ++i) {
33+
if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
34+
squeeze_dims.push_back(dims->at(i));
35+
}
36+
}
37+
if (squeeze_dims.size() == 0) {
38+
add_clone_node(graph, in, out);
39+
} else {
40+
std::vector<int64_t> permute_dims(t_in->dim());
41+
for (int i = 0; i < t_in->dim(); ++i) {
42+
permute_dims.at(i) = i;
43+
}
44+
for (auto& elem : squeeze_dims) {
45+
auto it = std::find(permute_dims.begin(), permute_dims.end(), elem);
46+
VK_CHECK_COND(
47+
it != permute_dims.end(), "Squeeze dim not found in permute_dims");
48+
std::rotate(permute_dims.begin(), it, it + 1);
49+
}
50+
51+
add_permute_node(graph, in, permute_dims, out);
52+
}
53+
}
54+
55+
void squeeze_copy_dims(ComputeGraph& graph, const std::vector<ValueRef>& args) {
56+
return add_squeeze_copy_dims_node(graph, args[0], args[1], args[2]);
57+
}
58+
59+
REGISTER_OPERATORS {
60+
VK_REGISTER_OP(aten.squeeze_copy.dims, squeeze_copy_dims);
61+
}
62+
63+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,3 +1038,22 @@ def get_minimum_inputs():
10381038
]
10391039
)
10401040
return test_suite
1041+
1042+
1043+
@register_test_suite("aten.squeeze_copy.dims")
1044+
def get_squeeze_copy_dim_inputs():
1045+
test_suite = VkTestSuite(
1046+
[
1047+
([S, S, S, 1], 3),
1048+
([S, 1, S, S], 1),
1049+
([S, 1, 1, S], [1, 2]),
1050+
([1, S, S, S], 0),
1051+
([S, S, S, S], 3),
1052+
([S, S, S, S], 2),
1053+
([S, S, S, S], 1),
1054+
([M, M1, 1], 2),
1055+
([M, 1, M1], 1),
1056+
([1, M1, M1], 0),
1057+
]
1058+
)
1059+
return test_suite

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,6 +1157,22 @@ def forward(self, x):
11571157
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
11581158
)
11591159

1160+
def test_vulkan_backend_squeeze(self):
1161+
class SqueezeModule(torch.nn.Module):
1162+
def __init__(self):
1163+
super().__init__()
1164+
1165+
def forward(self, x):
1166+
return torch.squeeze(x, 0)
1167+
1168+
sample_inputs = (torch.randn(size=(1, 2, 2, 1), dtype=torch.float32),)
1169+
1170+
self.lower_module_and_test_output(
1171+
SqueezeModule(),
1172+
sample_inputs,
1173+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
1174+
)
1175+
11601176
def test_vulkan_backend_select(self):
11611177
class SelectModule(torch.nn.Module):
11621178
def __init__(self):

0 commit comments

Comments
 (0)