Implement aten.squeeze_copy.dims #4223

Closed · wants to merge 1 commit

1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -97,6 +97,7 @@ def __contains__(self, op):
 ]

 SHAPE_MANIPULATION_OPS = [
+    exir_ops.edge.aten.squeeze_copy.dims,
     exir_ops.edge.aten.unsqueeze_copy.default,
     exir_ops.edge.aten.view_copy.default,
     exir_ops.edge.aten.permute_copy.default,
19 changes: 19 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Clone.h
@@ -0,0 +1,19 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

namespace vkcompute {

void add_clone_node(ComputeGraph& graph, const ValueRef in, const ValueRef out);

} // namespace vkcompute
16 changes: 7 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -35,11 +35,6 @@ void check_args(
   // dim size as the argument. The code will work as long as the input tensor's
   // dim size is shorter than the permute dim array. In this case, the code
   // assumes a size of 1 at the higher dimensions.
-
-  int64_t out_dim = out.dim();
-  VK_CHECK_COND(
-      out_dim == permute_dims.size(),
-      "Output tensor dim size must match argument");
 }

 } // namespace
@@ -56,15 +51,18 @@

   ivec4 out_dims{0, 1, 2, 3};

-  int64_t out_dim = t_out->dim();
-  std::vector<bool> seen(out_dim);
-  for (int i = 0; i < t_out->dim(); i++) {
+  // Special case for squeeze/unsqueeze: the input dim count can differ from
+  // the output dim count, so take the larger of the two (t_in->dim() for
+  // squeeze, t_out->dim() for unsqueeze) when building the permute parameters.
+  int64_t out_ndim = std::max(t_in->dim(), t_out->dim());
+  std::vector<bool> seen(out_ndim);
+  for (int i = 0; i < out_ndim; i++) {
     int64_t permute_dim = permute_dims[i];
     VK_CHECK_COND(
         !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
     seen[permute_dim] = true;

-    out_dims.data[(4u - out_dim) + i] = permute_dim + (4 - out_dim);
+    out_dims.data[(4u - out_ndim) + i] = permute_dim + (4 - out_ndim);
   }

   std::string kernel_name = "permute";
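For intuition, here is a minimal sketch (plain Python, not part of this PR, with a hypothetical helper name) of the padded-index arithmetic in the loop above. Vulkan tensors are padded to 4 dims with implicit leading dims of size 1, and out_ndim is assumed to equal the length of permute_dims, as the repeated-dim check guarantees:

def pad_permute_dims(permute_dims):
    # Mirrors the C++ loop: place an n-dim permutation into the trailing
    # n slots of a 4-slot vector; leading slots keep their identity values.
    out_ndim = len(permute_dims)
    out_dims = [0, 1, 2, 3]
    for i, permute_dim in enumerate(permute_dims):
        out_dims[(4 - out_ndim) + i] = permute_dim + (4 - out_ndim)
    return out_dims

# A 3-dim permutation [1, 0, 2] lands in slots 1..3; slot 0 stays identity.
assert pad_permute_dims([1, 0, 2]) == [0, 2, 1, 3]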
63 changes: 63 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Squeeze.cpp
@@ -0,0 +1,63 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Clone.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <algorithm> // for std::find and std::rotate

namespace vkcompute {

void add_squeeze_copy_dims_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef dims_ref,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  IntListPtr dims = graph.get_int_list(dims_ref);
  std::vector<int64_t> squeeze_dims;
  // Filter out dims that don't actually need a squeeze:
  // 1. The size of the squeeze dim is larger than 1.
  // 2. The squeeze dim is the outermost dim.
  // If no dims survive the filter, just pass the input through to the output
  // via clone.
  for (int i = 0; i < dims->size(); ++i) {
    if (dims->at(i) != 0 && t_in->sizes().at(dims->at(i)) == 1) {
      squeeze_dims.push_back(dims->at(i));
    }
  }
  if (squeeze_dims.size() == 0) {
    add_clone_node(graph, in, out);
  } else {
    // Start from the identity permutation.
    std::vector<int64_t> permute_dims(t_in->dim());
    for (int i = 0; i < t_in->dim(); ++i) {
      permute_dims.at(i) = i;
    }
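    // Move each squeeze dim to the front of permute_dims. After the permute,
    // the size-1 dims lead, and the lower-rank output tensor drops them
    // implicitly (see the out_ndim handling in add_permute_node).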
    for (auto& elem : squeeze_dims) {
      auto it = std::find(permute_dims.begin(), permute_dims.end(), elem);
      VK_CHECK_COND(
          it != permute_dims.end(), "Squeeze dim not found in permute_dims");
      std::rotate(permute_dims.begin(), it, it + 1);
    }

    add_permute_node(graph, in, permute_dims, out);
  }
}

void squeeze_copy_dims(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_squeeze_copy_dims_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.squeeze_copy.dims, squeeze_copy_dims);
}

} // namespace vkcompute
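The squeeze-as-permute approach can be sanity-checked against eager PyTorch. The sketch below is illustrative only (the helper name is hypothetical, and it assumes a torch version whose torch.squeeze_copy accepts a list of dims, matching the aten.squeeze_copy.dims schema); the list manipulation mirrors the std::rotate calls above:

import torch

def squeeze_via_permute(x, squeeze_dims):
    permute_dims = list(range(x.dim()))
    for d in squeeze_dims:
        i = permute_dims.index(d)
        # Same effect as std::rotate(begin, it, it + 1): move element i
        # to the front, shifting earlier elements right by one.
        permute_dims = [permute_dims[i]] + permute_dims[:i] + permute_dims[i + 1:]
    y = x.permute(permute_dims)
    # Drop the leading size-1 dims, as the lower-rank output tensor does.
    return y.reshape(y.shape[len(squeeze_dims):])

x = torch.randn(3, 1, 1, 5)
assert torch.equal(squeeze_via_permute(x, [1, 2]), torch.squeeze_copy(x, [1, 2]))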
19 changes: 19 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -1038,3 +1038,22 @@ def get_minimum_inputs():
         ]
     )
     return test_suite
+
+
+@register_test_suite("aten.squeeze_copy.dims")
+def get_squeeze_copy_dim_inputs():
+    test_suite = VkTestSuite(
+        [
+            ([S, S, S, 1], 3),
+            ([S, 1, S, S], 1),
+            ([S, 1, 1, S], [1, 2]),
+            ([1, S, S, S], 0),
+            ([S, S, S, S], 3),
+            ([S, S, S, S], 2),
+            ([S, S, S, S], 1),
+            ([M, M1, 1], 2),
+            ([M, 1, M1], 1),
+            ([1, M1, M1], 0),
+        ]
+    )
+    return test_suite
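Each tuple above pairs an input shape with the dim (or list of dims) to squeeze; S, M, and M1 are symbolic sizes defined elsewhere in cases.py. With a hypothetical stand-in value for S, the reference semantics of a few cases, covering both the permute path and the clone passthrough:

import torch

S = 5  # stand-in value; the test harness supplies its own sizes
# ([S, 1, 1, S], [1, 2]): size-1 inner dims removed via the permute path.
assert torch.squeeze_copy(torch.randn(S, 1, 1, S), [1, 2]).shape == (S, S)
# ([S, S, S, S], 3): the squeeze dim has size > 1, so the shape is unchanged.
assert torch.squeeze_copy(torch.randn(S, S, S, S), [3]).shape == (S, S, S, S)
# ([1, S, S, S], 0): outermost size-1 dim; the shape shrinks, but the packed
# data is identical, so the backend only needs a clone.
assert torch.squeeze_copy(torch.randn(1, S, S, S), [0]).shape == (S, S, S)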
16 changes: 16 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1157,6 +1157,22 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )

+    def test_vulkan_backend_squeeze(self):
+        class SqueezeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.squeeze(x, 0)
+
+        sample_inputs = (torch.randn(size=(1, 2, 2, 1), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            SqueezeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_select(self):
         class SelectModule(torch.nn.Module):
             def __init__(self):