Skip to content

[ET-VK][8/n] Unsqueeze #3172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions backends/vulkan/runtime/graph/ops/impl/Permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
Expand All @@ -20,53 +23,51 @@ using api::utils::uvec4;

void check_args(
const vTensor& in,
const IntListPtr& permute_dims,
const std::vector<int64_t>& permute_dims,
const vTensor& out) {
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));

int64_t in_dim = in.dim();
// This implementation doesn't not requires the input tensor to have the same
// dim size as the argument. The code will work as long as the input tensor's
// dim size is shorter than the permute dim array. In this case, the code
// assume size of 1 at the higher dimensions.

int64_t out_dim = out.dim();
VK_CHECK_COND(
in_dim == permute_dims->size(),
"Input tensor dim size must match argument");
out_dim == permute_dims.size(),
"Output tensor dim size must match argument");
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
ValueRef permute_dims_ref,
const std::vector<int64_t>& permute_dims,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);

check_args(*t_in, permute_dims, *t_out);

uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
uvec4 out_dims{0u, 1u, 2u, 3u};

int64_t in_dim = t_in->dim();

std::vector<bool> seen(in_dim);
for (int i = 0; i < in_dim; i++) {
int64_t permute_dim = (*permute_dims)[i];
int64_t out_dim = t_out->dim();
std::vector<bool> seen(out_dim);
for (int i = 0; i < t_out->dim(); i++) {
int64_t permute_dim = permute_dims[i];
VK_CHECK_COND(
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
seen[permute_dim] = true;

// Map to 4D tensor dims.
in_size.data[(4u - in_dim) + i] = t_in->size(i);
out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim);
out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim);
out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim);
}

std::string kernel_name = "permute";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);

uint32_t out_channels = out_size.data[1u];
uint32_t in_channels = in_size.data[1u];
uint32_t out_channels = dim_at<Dim4D::Channel>(t_out->sizes());
uint32_t in_channels = dim_at<Dim4D::Channel>(t_in->sizes());

uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
Expand Down Expand Up @@ -98,6 +99,16 @@ void add_permute_node(
{}));
}

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
ValueRef permute_dims_ref,
ValueRef out) {
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);

add_permute_node(graph, in, *permute_dims, out);
}

void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_permute_node(graph, args[0], args[1], args[2]);
}
Expand Down
25 changes: 25 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Permute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <vector>

namespace vkcompute {

void add_permute_node(
ComputeGraph& graph,
ValueRef in,
const std::vector<int64_t>& permute_dims,
ValueRef out);

} // namespace vkcompute
52 changes: 52 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_unsqueeze_node(
ComputeGraph& graph,
ValueRef in,
ValueRef dim_ref,
ValueRef out) {
vTensorPtr t_in = graph.get_tensor(in);
vTensorPtr t_out = graph.get_tensor(out);

VK_CHECK_COND(
t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions");

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
int64_t out_dim = t_out->dim();

std::vector<int64_t> permute_dims(out_dim);
for (int i = 1; i <= dim; i++) {
permute_dims[i - 1] = i;
}
permute_dims[dim] = 0;

for (int i = dim + 1; i < out_dim; i++) {
permute_dims[i] = i;
}

add_permute_node(graph, in, permute_dims, out);
}

void unsqueeze(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_unsqueeze_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze);
}

} // namespace vkcompute
28 changes: 28 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def get_permute_inputs():
((9, 2), [1, 0]),
]
)

test_suite.layouts = ["api::kChannelsPacked"]
return test_suite

Expand Down Expand Up @@ -334,6 +335,32 @@ def get_slice_inputs():
return test_suite


def get_unsqueeze_inputs():
test_suite = VkTestSuite(
[
((2, 3, 4), 0),
((1, 1, 1), 0),
((1, 1, 1), 1),
((1, 1, 1), 2),
((1, 1, 1), 3),
((9, 9, 9), 0),
((9, 9, 9), 1),
((9, 9, 9), 2),
((9, 9, 9), 3),
((9, 9), 0),
((9, 9), 1),
((9, 9), 2),
((9,), 0),
((9,), 1),
]
)
test_suite.layouts = [
"api::kChannelsPacked",
]
test_suite.data_gen = "make_seq_tensor"
return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
Expand All @@ -350,4 +377,5 @@ def get_slice_inputs():
"aten.permute_copy.default": get_permute_inputs(),
"aten.view_copy.default": get_view_inputs(),
"aten.slice_copy.Tensor": get_slice_inputs(),
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
}