Skip to content

Commit bf9d119

Browse files
committed
[ET-VK][10/n] copy node, aten.repeat
Pull Request resolved: #3299. 1. Introduce a `CopyNode` for generic copy-with-offset operations. 2. Support `aten.repeat` on all dimensions: 2.1. Use `CopyNode` where possible. 2.2. Add a specialized `repeat_channel` shader to handle packing. 3. Update codegen to support operations that only have a `Methods` variant; these need a new route to trigger the dispatch. ghstack-source-id: 223792412 Differential Revision: [D56499329](https://our.internmc.facebook.com/intern/diff/D56499329/)
1 parent b669056 commit bf9d119

File tree

13 files changed

+565
-4
lines changed

13 files changed

+565
-4
lines changed

backends/vulkan/runtime/api/Tensor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ class vTensor final {
220220
*/
221221
const api::BufferBindInfo texture_limits_ubo();
222222

223+
/*
 * Returns, by value, the `limits` field of this tensor's `texture_limits_`
 * member.
 */
inline const api::utils::ivec3 texture_limits() const {
  const auto& current_limits = texture_limits_.limits;
  return current_limits;
}
226+
223227
inline size_t numel() const {
224228
return api::utils::multiply_integers(sizes());
225229
}

backends/vulkan/runtime/api/Utils.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,23 @@ inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
262262
return os;
263263
}
264264

265+
inline std::ostream& operator<<(std::ostream& os, const ivec3& v) {
266+
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
267+
return os;
268+
}
269+
265270
inline std::ostream& operator<<(std::ostream& os, const uvec4& v) {
266271
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
267272
<< v.data[3u] << ")";
268273
return os;
269274
}
270275

276+
inline std::ostream& operator<<(std::ostream& os, const ivec4& v) {
277+
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
278+
<< v.data[3u] << ")";
279+
return os;
280+
}
281+
271282
//
272283
// std::vector<T> Handling
273284
//
@@ -298,6 +309,25 @@ inline ivec2 make_ivec2(
298309
}
299310
}
300311

312+
/*
 * Builds an ivec3 from a vector holding exactly 3 integers (checked with
 * VK_CHECK_COND). Each element goes through safe_downcast<int32_t>. When
 * `reverse` is true the elements are emitted in the order [2], [1], [0];
 * otherwise in the order [0], [1], [2].
 */
inline ivec3 make_ivec3(
    const std::vector<int64_t>& ints,
    bool reverse = false) {
  VK_CHECK_COND(ints.size() == 3);

  // Only the outer two positions swap; the middle element is always ints[1].
  const auto first_idx = reverse ? 2u : 0u;
  const auto last_idx = reverse ? 0u : 2u;

  return {
      safe_downcast<int32_t>(ints[first_idx]),
      safe_downcast<int32_t>(ints[1]),
      safe_downcast<int32_t>(ints[last_idx]),
  };
}
330+
301331
inline ivec4 make_ivec4(
302332
const std::vector<int64_t>& ints,
303333
bool reverse = false) {
@@ -338,6 +368,13 @@ inline ivec3 make_ivec3(uvec3 ints) {
338368
safe_downcast<int32_t>(ints.data[2u])};
339369
}
340370

371+
/*
 * Converts an ivec3 into a uvec3, passing every component through
 * safe_downcast<uint32_t>.
 */
inline uvec3 make_uvec3(ivec3 ints) {
  const auto x = safe_downcast<uint32_t>(ints.data[0u]);
  const auto y = safe_downcast<uint32_t>(ints.data[1u]);
  const auto z = safe_downcast<uint32_t>(ints.data[2u]);
  return {x, y, z};
}
377+
341378
/*
342379
* Given a vector of up to 4 uint64_t representing the sizes of a tensor,
343380
* constructs a uvec4 containing those elements in reverse order.

backends/vulkan/runtime/graph/Logging.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
3434
return api::utils::operator<<(os, v);
3535
}
3636

37+
// Forwards ivec3 streaming to the overload defined in api::utils. The callee
// is named explicitly rather than written as `os << v`.
inline std::ostream& operator<<(std::ostream& os, const api::utils::ivec3& v) {
  std::ostream& result = api::utils::operator<<(os, v);
  return result;
}
40+
41+
// Forwards ivec4 streaming to the overload defined in api::utils. The callee
// is named explicitly rather than written as `os << v`.
inline std::ostream& operator<<(std::ostream& os, const api::utils::ivec4& v) {
  std::ostream& result = api::utils::operator<<(os, v);
  return result;
}
44+
3745
template <typename T>
3846
inline std::ostream& operator<<(std::ostream& os, const std::optional<T>& opt) {
3947
os << "[";
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
20+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
21+
22+
layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
23+
ivec3 out_limits;
24+
};
25+
26+
layout(set = 0, binding = 3) uniform PRECISION restrict InLimits {
27+
ivec3 in_limits;
28+
};
29+
30+
31+
32+
layout(set = 0, binding = 4) uniform PRECISION restrict CopyArgs {
33+
ivec3 range;
34+
int unused0;
35+
ivec3 src_offset;
36+
int unused1;
37+
ivec3 dst_offset;
38+
int unused2;
39+
};
40+
41+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
42+
43+
// Copies one texel per invocation: the texel at (pos + src_offset) in
// image_in is written to (pos + dst_offset) in image_out, for every pos that
// lies inside `range`.
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations that fall outside the requested copy extent do nothing.
  if (any(greaterThanEqual(pos, range))) {
    return;
  }

  const ivec3 in_pos = pos + src_offset;
  const ivec3 out_pos = pos + dst_offset;

  imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0));
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
copy_offset:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: copy_offset
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
layout(std430) buffer;
16+
17+
#include "indexing_utils.h"
18+
19+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
20+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
21+
22+
layout(set = 0, binding = 2) uniform PRECISION restrict RepeatArgs {
23+
// With input_size (n, c_i, h, w) and repeat r
24+
// out_size == (n, c_i * r, h, w)
25+
ivec4 out_sizes;
26+
ivec4 in_sizes;
27+
};
28+
29+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
30+
31+
layout(constant_id = 3) const int packed_dim = C_DIM;
32+
33+
34+
// Produces one output texel per invocation for a repeat along the channel
// dimension.
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_whcn = to_tensor_idx(out_pos, out_sizes, packed_dim);

  // Skip invocations whose tensor index is outside the output tensor.
  if (any(greaterThanEqual(out_whcn, out_sizes))) {
    return;
  }

  VEC4_T v;

  // Fill the four components of the output texel one at a time: component
  // `lane` takes the input element whose channel index is the output channel
  // wrapped modulo the input channel count. This is not the most efficient
  // approach, since the same input texel is likely fetched more than once
  // within the loop.
  for (int lane = 0; lane < 4; lane++) {
    const ivec4 in_whcn = ivec4(
        out_whcn.x,
        out_whcn.y,
        (out_whcn.z + lane) % in_sizes.z,
        out_whcn.w);

    const ivec4 in_elem_pos = to_texture_elem_pos(in_whcn, in_sizes, packed_dim);
    const VEC4_T src_texel = VEC4_T(texelFetch(image_in, in_elem_pos.xyz, 0));

    // in_elem_pos.w selects the component inside the fetched texel.
    v[lane] = src_texel[in_elem_pos.w];
  }

  imageStore(image_out, out_pos, v);
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
repeat_channel:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: repeat_channel
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17+
18+
#include <iostream>
19+
20+
namespace vkcompute {
21+
22+
/*
 * Appends an ExecuteNode to `graph` that dispatches the `copy_offset` shader
 * (with a dtype suffix derived from the output tensor).
 *
 * For every position `pos` inside `range`, the shader reads the texel at
 * `pos + src_offset` from `in` and writes it to `pos + dst_offset` in `out`.
 *
 * @param graph       graph that owns both tensors and receives the new node.
 * @param in          source tensor; must use the channels-packed layout.
 * @param range       extent of the region to copy; also used as the global
 *                    work group size, so it is converted with make_uvec3.
 * @param src_offset  offset added to each position when reading from `in`.
 * @param dst_offset  offset added to each position when writing to `out`.
 * @param out         destination tensor; must use the channels-packed layout.
 */
void add_copy_offset_node(
    ComputeGraph& graph,
    const ValueRef in,
    const api::utils::ivec3& range,
    const api::utils::ivec3& src_offset,
    const api::utils::ivec3& dst_offset,
    const ValueRef out) {
  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);

  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));

  std::string kernel_name = "copy_offset";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, *t_out);

  // One invocation per position in the copied region.
  api::utils::uvec3 global_size = api::utils::make_uvec3(range);
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  // Field order mirrors the CopyArgs uniform block of the copy_offset shader:
  // each ivec3 is followed by one int32 filler member.
  // NOTE(review): the fillers presumably keep each ivec3 on a 16-byte
  // boundary to match the shader-side layout; confirm.
  const struct Block final {
    api::utils::ivec3 range;
    int32_t unused0;
    api::utils::ivec3 src_offset;
    int32_t unused1;
    api::utils::ivec3 dst_offset;
    int32_t unused2;
  } offset_params{
      range,
      0,
      src_offset,
      0,
      dst_offset,
      0,
  };

  // The kernel is looked up once, directly in the constructor call; the
  // previously unused `shader` local that duplicated this lookup was removed.
  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
      // Parameter buffers
      {t_out->texture_limits_ubo(),
       t_in->texture_limits_ubo(),
       graph.create_params_buffer(offset_params)},
      // Specialization Constants
      {}));
}
74+
75+
} // namespace vkcompute
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
15+
namespace vkcompute {
16+
17+
/*
 * Adds a node to `graph` that copies the region of extent `range` from tensor
 * `in`, starting at `src_offset`, into tensor `out`, starting at `dst_offset`.
 */
void add_copy_offset_node(
    ComputeGraph& graph,
    const ValueRef in,
    const api::utils::ivec3& range,
    const api::utils::ivec3& src_offset,
    const api::utils::ivec3& dst_offset,
    const ValueRef out);
24+
25+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ using api::utils::ivec3;
2121
using api::utils::uvec2;
2222
using api::utils::uvec4;
2323

24+
namespace {
25+
2426
void check_args(
2527
const vTensor& in,
2628
const std::vector<int64_t>& permute_dims,
@@ -39,6 +41,8 @@ void check_args(
3941
"Output tensor dim size must match argument");
4042
}
4143

44+
} // namespace
45+
4246
void add_permute_node(
4347
ComputeGraph& graph,
4448
ValueRef in,

0 commit comments

Comments
 (0)