Skip to content

Commit c51d18f

Browse files
committed
[ET-VK][3/n] aten.permute_copy.default
Implementation adopted from LI, with clean-up. Differential Revision: [D56093765](https://our.internmc.facebook.com/intern/diff/D56093765/) ghstack-source-id: 222802330 Pull Request resolved: #3086
1 parent d46653b commit c51d18f

File tree

7 files changed

+237
-5
lines changed

7 files changed

+237
-5
lines changed

backends/vulkan/runtime/api/Tensor.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,14 @@ class vTensor final {
149149
// to be interpreted as a tensor with a different size.
150150
api::utils::uvec3 virtual_extents_;
151151

152-
// A Vulkan uniform buffer containing the tensor sizes that can be passed into
153-
// a shader.
152+
// A Vulkan uniform buffer containing the tensor sizes in WHCN that can be
153+
// passed into a shader.
154154
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_uniform_;
155155

156-
// A Vulkan uniform buffer containing the GPU tensor sizes that can be passed
157-
// into a shader. GPU sizes refers to the sizes of the tensor after padding
158-
// has been applied to one dimension to align it to the next multiple of 4.
156+
// A Vulkan uniform buffer containing the GPU tensor sizes in WHCN that can
157+
// be passed into a shader. GPU sizes refers to the sizes of the tensor after
158+
// padding has been applied to one dimension to align it to the next multiple
159+
// of 4.
159160
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_uniform_;
160161

161162
// A Vulkan uniform buffer containing the image extents of the underlying

backends/vulkan/runtime/api/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
262262
return os;
263263
}
264264

265+
inline std::ostream& operator<<(std::ostream& os, const uvec4& v) {
266+
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
267+
<< v.data[3u] << ")";
268+
return os;
269+
}
270+
265271
//
266272
// std::vector<T> Handling
267273
//

backends/vulkan/runtime/graph/Logging.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,8 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec3& v) {
2929
return api::utils::operator<<(os, v);
3030
}
3131

32+
// Forwards to the api::utils overload so that a uvec4 can be streamed from
// code that lives in this namespace.
inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
  std::ostream& stream = api::utils::operator<<(os, v);
  return stream;
}
35+
3236
} // namespace vkcompute
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
// binding 0: destination image, write-only.
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
// binding 1: source image, read through texelFetch.
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;

// binding 2: sizes of the output tensor.
layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents {
  // tensor size in WHCN.
  uvec4 data;
}
out_sizes;

/*
 * Params Buffer (binding 3)
 */
layout(set = 0, binding = 3) uniform PRECISION restrict Block {
  // out_ndims[i] is the slot of the (b, c, h, w) source index that receives
  // the i-th output coordinate, where i = 0..3 is batch, channel, height, width
  uvec4 out_ndims;
  // x = output channels aligned to 4, y = input channels aligned to 4
  uvec2 ch_info;
}
uBlock;
38+
39+
/*
40+
* Local Work Group
41+
*/
42+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
43+
44+
/*
 * Each invocation writes one texel of image_out. For every one of the four
 * components of that texel, the output (batch, channel, y, x) coordinate is
 * scattered into a (b, c, h, w) source index through uBlock.out_ndims, and the
 * matching component is fetched from image_in.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  // Nothing to do for invocations that map outside of the output sizes.
  const ivec4 out_idx = to_tensor_idx_C_packed(out_pos, out_sizes.data);
  if (any(greaterThanEqual(out_idx, out_sizes.data))) {
    return;
  }

  const int out_c_4up = int(uBlock.ch_info.x);
  const int in_c_4up = int(uBlock.ch_info.y);

  // Exclusive upper bound for the flattened (batch, aligned channel) index.
  const int dst_limit = int(out_sizes.data[3]) * out_c_4up;

  // Components that are never written below stay at zero.
  VEC4_T texel_out = VEC4_T(0.0);

  for (int lane = 0; lane < 4; ++lane) {
    const int dst = out_pos.z * 4 + lane;
    if (dst >= dst_limit) {
      // dst only grows with lane, so the remaining lanes are out of range too
      break;
    }

    // Source index in (b, c, h, w) order, filled in via the permutation.
    ivec4 src_bchw = ivec4(0);
    src_bchw[uBlock.out_ndims[0]] = dst / out_c_4up;
    src_bchw[uBlock.out_ndims[1]] = dst % out_c_4up;
    src_bchw[uBlock.out_ndims[2]] = out_pos.y;
    src_bchw[uBlock.out_ndims[3]] = out_pos.x;

    // Flattened (batch, aligned channel) index into the source: src / 4 picks
    // the texel along z and src % 4 the component within it.
    const int src = src_bchw[0] * in_c_4up + src_bchw[1];

    const VEC4_T texel_in =
        VEC4_T(texelFetch(image_in, ivec3(src_bchw[3], src_bchw[2], src / 4), 0));
    texel_out[lane] = texel_in[src % 4];
  }

  imageStore(image_out, out_pos, texel_out);
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Code-generation settings for the permute shader template.
permute:
  # Values substituted into the template when a variant does not override them.
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
  # One shader is generated per listed DTYPE value.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: permute
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
16+
17+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
18+
#include <iostream>
19+
20+
namespace vkcompute {
21+
22+
using api::utils::ivec3;
23+
using api::utils::uvec2;
24+
using api::utils::uvec4;
25+
26+
void check_args(
27+
const vTensor& in,
28+
const IntListPtr& permute_dims,
29+
const vTensor& out) {
30+
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
31+
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
32+
33+
int64_t in_dim = in.dim();
34+
VK_CHECK_COND(
35+
in_dim == permute_dims->size(),
36+
"Input tensor dim size must match argument");
37+
}
38+
39+
// Appends an ExecuteNode to `graph` that runs the "permute" shader, writing
// the permutation of tensor `in` into tensor `out`.
//
//   graph            - compute graph that owns the tensors and receives the node
//   in               - reference to the input tensor
//   permute_dims_ref - reference to the int list naming, for each output
//                      dimension, the input dimension it is taken from.
//                      Negative entries count from the last dimension.
//   out              - reference to the output tensor
//
// Tensors with fewer than four dimensions are treated as 4D by left-padding
// with dimensions of size 1. VK_CHECK_COND fails if the input has more than
// four dimensions, or if an entry of the list is out of range or repeated.
void add_permute_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef permute_dims_ref,
    ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  IntListPtr permute_dims = graph.get_int_list(permute_dims_ref);

  check_args(*t_in, permute_dims, *t_out);

  const int64_t in_dim = t_in->dim();
  // Everything below is stored in 4-element vectors; a larger rank would
  // index them out of bounds.
  VK_CHECK_COND(in_dim <= 4, "permute supports at most 4 dims, got ", in_dim);

  // Number of leading size-1 dimensions added to reach 4D.
  const int64_t dim_offset = 4 - in_dim;

  uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
  // Identity permutation; the padded leading dimensions keep these values.
  uvec4 out_dims{0u, 1u, 2u, 3u};

  std::vector<bool> seen(in_dim);
  for (int64_t i = 0; i < in_dim; i++) {
    const int64_t requested_dim = (*permute_dims)[i];
    // Wrap negative dims; validate before the value is used as an index.
    const int64_t permute_dim =
        requested_dim < 0 ? requested_dim + in_dim : requested_dim;
    VK_CHECK_COND(
        permute_dim >= 0 && permute_dim < in_dim,
        "Argument dim ",
        requested_dim,
        " is out of range");
    VK_CHECK_COND(
        !seen[permute_dim], "Argument dim ", permute_dim, " is repeated");
    seen[permute_dim] = true;

    // Map to 4D tensor dims.
    in_size.data[dim_offset + i] = t_in->size(i);
    out_size.data[dim_offset + i] = t_in->size(permute_dim);
    out_dims.data[dim_offset + i] =
        static_cast<uint32_t>(permute_dim + dim_offset);
  }

  std::string kernel_name = "permute";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  // Index 1 of the 4D sizes is the channel dimension.
  uint32_t out_channels = out_size.data[1u];
  uint32_t in_channels = in_size.data[1u];

  uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
  uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);

  // Mirrors the `Block` uniform declared by the permute shader.
  const struct Block final {
    uvec4 out_ndims;
    uvec2 ch_info;
  } params{
      out_dims,
      {out_c_aligned, in_c_aligned},
  };

  api::utils::uvec3 global_size = t_out->virtual_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
      {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)}));
}
98+
99+
// Operator entry point: args[0] is the input tensor, args[1] the permutation
// list and args[2] the output tensor.
void permute(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  add_permute_node(graph, args[0], args[1], args[2]);
}
102+
103+
// Registers `permute` as the handler for aten.permute_copy.default.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.permute_copy.default, permute);
}
106+
107+
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,29 @@ def get_select_int_inputs():
171171
return test_suite
172172

173173

174+
def get_permute_inputs():
    """Build the test suite registered for ``aten.permute_copy.default``.

    Every case is a ``(shape, dims)`` pair; presumably the input tensor shape
    followed by the permutation applied to it. The suite is restricted to the
    channels-packed layout.
    """
    # Permutations grouped by the shape they are paired with. Order is kept so
    # the generated case list stays the same.
    dims_by_shape = [
        (
            (9, 2, 9, 4),
            [
                [0, 1, 2, 3],
                [0, 1, 3, 2],
                [0, 2, 1, 3],
                [0, 2, 3, 1],
                [0, 3, 1, 2],
                [0, 3, 2, 1],
                [3, 0, 1, 2],
                [3, 2, 0, 1],
                [2, 3, 0, 1],
                [2, 0, 3, 1],
            ],
        ),
        ((9, 2, 9), [[2, 0, 1], [1, 2, 0]]),
        ((9, 2), [[0, 1], [1, 0]]),
    ]
    cases = [
        (shape, dims) for shape, dims_list in dims_by_shape for dims in dims_list
    ]

    test_suite = VkTestSuite(cases)
    test_suite.layouts = ["api::kChannelsPacked"]
    return test_suite
195+
196+
174197
test_suites = {
175198
"aten.add.Tensor": get_binary_elementwise_inputs(),
176199
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -183,4 +206,5 @@ def get_select_int_inputs():
183206
"aten.full.default": get_full_inputs(),
184207
"aten.select.int": get_select_int_inputs(),
185208
"aten.select_copy.int": get_select_int_inputs(),
209+
"aten.permute_copy.default": get_permute_inputs(),
186210
}

0 commit comments

Comments
 (0)