Skip to content

Commit de00717

Browse files
yipjustinfacebook-github-bot
authored andcommitted
aten.permute_copy.default (#3086)
Summary: Pull Request resolved: #3086 Implementation adopted from LI, with clean-up. ghstack-source-id: 222906934 Reviewed By: copyrightly Differential Revision: D56093765 fbshipit-source-id: 0ed78ae06e5b106a92cf3c1fdc85179f1e829919
1 parent 49928bc commit de00717

File tree

7 files changed

+231
-5
lines changed

7 files changed

+231
-5
lines changed

backends/vulkan/runtime/api/Tensor.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,14 @@ class vTensor final {
149149
// to be interpreted as a tensor with a different size.
150150
api::utils::uvec3 virtual_extents_;
151151

152-
// A Vulkan uniform buffer containing the tensor sizes that can be passed into
153-
// a shader.
152+
// A Vulkan uniform buffer containing the tensor sizes in WHCN that can be
153+
// passed into a shader.
154154
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_uniform_;
155155

156-
// A Vulkan uniform buffer containing the GPU tensor sizes that can be passed
157-
// into a shader. GPU sizes refers to the sizes of the tensor after padding
158-
// has been applied to one dimension to align it to the next multiple of 4.
156+
// A Vulkan uniform buffer containing the GPU tensor sizes in WHCN that can
157+
// be passed into a shader. GPU sizes refers to the sizes of the tensor after
158+
// padding has been applied to one dimension to align it to the next multiple
159+
// of 4.
159160
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_uniform_;
160161

161162
// A Vulkan uniform buffer containing the image extents of the underlying

backends/vulkan/runtime/api/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
262262
return os;
263263
}
264264

265+
inline std::ostream& operator<<(std::ostream& os, const uvec4& v) {
266+
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
267+
<< v.data[3u] << ")";
268+
return os;
269+
}
270+
265271
//
266272
// std::vector<T> Handling
267273
//

backends/vulkan/runtime/graph/Logging.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,8 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec3& v) {
2929
return api::utils::operator<<(os, v);
3030
}
3131

32+
inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
33+
return api::utils::operator<<(os, v);
34+
}
35+
3236
} // namespace vkcompute
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
20+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
21+
22+
layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
23+
// tensor size in WHCN.
24+
uvec4 data;
25+
}
26+
out_sizes;
27+
28+
/*
29+
* Params Buffer
30+
*/
31+
layout(set = 0, binding = 3) uniform PRECISION restrict Block {
32+
// output dims
33+
uvec4 out_ndims;
34+
// x = output channels aligned to 4, y = input channels aligned to 4
35+
uvec2 ch_info;
36+
}
37+
uBlock;
38+
39+
/*
40+
* Local Work Group
41+
*/
42+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
43+
44+
void main() {
45+
const ivec3 posOut = ivec3(gl_GlobalInvocationID);
46+
47+
const ivec4 idx = to_tensor_idx_C_packed(posOut, out_sizes.data);
48+
if (any(greaterThanEqual(idx, out_sizes.data))) {
49+
return;
50+
}
51+
52+
const int out_channel_4up = int(uBlock.ch_info.x);
53+
const int in_channel_4up = int(uBlock.ch_info.y);
54+
const int out_batch = int(out_sizes.data[3]);
55+
const int max_dst_index = out_batch * out_channel_4up;
56+
VEC4_T outval = VEC4_T(0.0);
57+
58+
for (int j = 0; j < 4; ++j) {
59+
int dst_index = posOut.z * 4 + j;
60+
if (dst_index >= max_dst_index) {
61+
// out of range
62+
break;
63+
}
64+
65+
ivec4 v = ivec4(0); // holds b,c,h,w
66+
v[uBlock.out_ndims[0]] = dst_index / out_channel_4up;
67+
v[uBlock.out_ndims[1]] = dst_index % out_channel_4up;
68+
v[uBlock.out_ndims[2]] = posOut.y;
69+
v[uBlock.out_ndims[3]] = posOut.x;
70+
71+
int src_index = v[0] * in_channel_4up + v[1];
72+
int w = v[3];
73+
int h = v[2];
74+
75+
VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0));
76+
outval[j] = inval[src_index % 4];
77+
}
78+
imageStore(image_out, posOut, outval);
79+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
permute:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: permute
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
14+
15+
namespace vkcompute {
16+
17+
using api::utils::ivec3;
18+
using api::utils::uvec2;
19+
using api::utils::uvec4;
20+
21+
void check_args(
22+
const vTensor& in,
23+
const IntListPtr& permute_dims,
24+
const vTensor& out) {
25+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
26+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
27+
28+
int64_t in_dim = in.dim();
29+
VK_CHECK_COND(
30+
in_dim == permute_dims->size(),
31+
"Input tensor dim size must match argument");
32+
}
33+
34+
void add_permute_node(
35+
ComputeGraph& graph,
36+
ValueRef in,
37+
ValueRef permute_dims_ref,
38+
ValueRef out) {
39+
vTensorPtr t_in = graph.get_tensor(in);
40+
vTensorPtr t_out = graph.get_tensor(out);
41+
42+
IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);
43+
44+
check_args(*t_in, permute_dims, *t_out);
45+
46+
uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
47+
uvec4 out_dims{0u, 1u, 2u, 3u};
48+
49+
int64_t in_dim = t_in->dim();
50+
51+
std::vector<bool> seen(in_dim);
52+
for (int i = 0; i < in_dim; i++) {
53+
int64_t permute_dim = (*permute_dims)[i];
54+
VK_CHECK_COND(
55+
!seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
56+
seen[permute_dim] = true;
57+
58+
// Map to 4D tensor dims.
59+
in_size.data[(4u - in_dim) + i] = t_in->size(i);
60+
out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim);
61+
out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim);
62+
}
63+
64+
std::string kernel_name = "permute";
65+
kernel_name.reserve(kShaderNameReserve);
66+
add_dtype_suffix(kernel_name, *t_out);
67+
68+
uint32_t out_channels = out_size.data[1u];
69+
uint32_t in_channels = in_size.data[1u];
70+
71+
uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
72+
uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);
73+
74+
const struct Block final {
75+
uvec4 out_ndims;
76+
uvec2 ch_info;
77+
} params{
78+
out_dims,
79+
{out_c_aligned, in_c_aligned},
80+
};
81+
82+
api::utils::uvec3 global_size = t_out->virtual_extents();
83+
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
84+
85+
graph.execute_nodes().emplace_back(new ExecuteNode(
86+
graph,
87+
VK_KERNEL_FROM_STR(kernel_name),
88+
global_size,
89+
local_size,
90+
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
91+
{t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)}));
92+
}
93+
94+
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
95+
return add_permute_node(graph, args[0], args[1], args[2]);
96+
}
97+
98+
REGISTER_OPERATORS {
99+
VK_REGISTER_OP(aten.permute_copy.default, permute);
100+
}
101+
102+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,29 @@ def get_select_int_inputs():
171171
return test_suite
172172

173173

174+
def get_permute_inputs():
175+
test_suite = VkTestSuite(
176+
[
177+
((9, 2, 9, 4), [0, 1, 2, 3]),
178+
((9, 2, 9, 4), [0, 1, 3, 2]),
179+
((9, 2, 9, 4), [0, 2, 1, 3]),
180+
((9, 2, 9, 4), [0, 2, 3, 1]),
181+
((9, 2, 9, 4), [0, 3, 1, 2]),
182+
((9, 2, 9, 4), [0, 3, 2, 1]),
183+
((9, 2, 9, 4), [3, 0, 1, 2]),
184+
((9, 2, 9, 4), [3, 2, 0, 1]),
185+
((9, 2, 9, 4), [2, 3, 0, 1]),
186+
((9, 2, 9, 4), [2, 0, 3, 1]),
187+
((9, 2, 9), [2, 0, 1]),
188+
((9, 2, 9), [1, 2, 0]),
189+
((9, 2), [0, 1]),
190+
((9, 2), [1, 0]),
191+
]
192+
)
193+
test_suite.layouts = ["api::kChannelsPacked"]
194+
return test_suite
195+
196+
174197
test_suites = {
175198
"aten.add.Tensor": get_binary_elementwise_inputs(),
176199
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -183,4 +206,5 @@ def get_select_int_inputs():
183206
"aten.full.default": get_full_inputs(),
184207
"aten.select.int": get_select_int_inputs(),
185208
"aten.select_copy.int": get_select_int_inputs(),
209+
"aten.permute_copy.default": get_permute_inputs(),
186210
}

0 commit comments

Comments
 (0)