Skip to content

Commit 6d1fed4

Browse files
yipjustinfacebook-github-bot
authored andcommitted
aten.select.int (#3033)
Summary: Port over the `select.int` shaders to ET. 1. Since in ET, tensor-shape reasoning happens in AOT, therefore we can simplify the c++ caller code by a lot. 2. In this diff, we also try to use the same buffer object for passing arguments to all shaders. Not worry about perf cost, since cost difference between passing int and ivec4 is very minor. Reviewed By: SS-JIA Differential Revision: D56082483
1 parent d0208d0 commit 6d1fed4

18 files changed

+657
-0
lines changed

backends/vulkan/runtime/api/Tensor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,14 @@ class vTensor final {
255255
return sizes_;
256256
}
257257

258+
inline const int64_t size(size_t dim) const {
259+
return sizes().at(dim);
260+
}
261+
262+
inline const int64_t dim() const {
263+
return sizes_.size();
264+
}
265+
258266
inline const std::vector<int64_t>& strides() const {
259267
return strides_;
260268
}

backends/vulkan/runtime/graph/Logging.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#pragma once
1010

11+
#include <executorch/backends/vulkan/runtime/api/Utils.h>
12+
1113
#include <ostream>
1214
#include <vector>
1315

@@ -23,4 +25,8 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
2325
return os; // Return the ostream to allow chaining
2426
}
2527

28+
inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec3& v) {
29+
return api::utils::operator<<(os, v);
30+
}
31+
2632
} // namespace vkcompute
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
layout(std430) buffer;
14+
15+
#include "indexing_utils.h"
16+
17+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
18+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
19+
20+
layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
21+
uvec4 data;
22+
}
23+
out_sizes;
24+
25+
layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal {
26+
// data.x: index along batch dim to select
27+
// data.y: number of batches
28+
// data.z: number of texels per batch
29+
// data.w: unused
30+
ivec4 data;
31+
}
32+
select_info;
33+
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
void main() {
37+
const int num_batches = select_info.data.y;
38+
const int num_texel_per_batch = select_info.data.z;
39+
const int index = select_info.data.x;
40+
41+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
42+
43+
const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data);
44+
45+
if (any(greaterThanEqual(idx, out_sizes.data))) {
46+
return;
47+
}
48+
49+
const uint src_pos_z = (num_texel_per_batch * index) + pos.z;
50+
imageStore(
51+
image_out, pos, texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0));
52+
}
53+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
select_batch_4d:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: select_batch_4d
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
#define T ${texel_component_type(DTYPE)}
15+
16+
layout(std430) buffer;
17+
18+
#include "indexing_utils.h"
19+
20+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
21+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
22+
23+
layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
24+
uvec4 data;
25+
}
26+
out_sizes;
27+
28+
// index to select
29+
layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal {
30+
int data;
31+
}
32+
index;
33+
34+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
35+
36+
void main() {
37+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
38+
39+
const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data);
40+
41+
if (any(greaterThanEqual(idx, out_sizes.data))) {
42+
return;
43+
}
44+
45+
const int tex = index.data / 4;
46+
const int ind = index.data % 4;
47+
const T v = VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, tex), 0))[ind];
48+
49+
imageStore(image_out, ivec3(pos.x, pos.y, 0), VEC4_T(v, 0, 0, 0));
50+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
select_channel_3d:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: select_channel_3d
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_type(DTYPE)}
13+
14+
layout(std430) buffer;
15+
16+
#include "indexing_utils.h"
17+
18+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
20+
21+
layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
22+
uvec4 data;
23+
}
24+
out_sizes;
25+
26+
layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal {
27+
// data.x: index along channel dim to select
28+
// data.y: number of batches
29+
// data.z: number of texels per batch
30+
// data.w: unused
31+
ivec4 data;
32+
}
33+
select_info;
34+
35+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
36+
37+
void main() {
38+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
39+
const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data);
40+
41+
if (any(greaterThanEqual(idx, out_sizes.data))) {
42+
return;
43+
}
44+
45+
const int num_batches = select_info.data.y;
46+
const int num_texel_per_batch = select_info.data.z;
47+
const int index = select_info.data.x;
48+
49+
// read in the same channel from 4 separate batches
50+
VEC4_T out_texel = VEC4_T(0, 0, 0, 0);
51+
for (int k = 0; k < 4; k++) {
52+
if ((k + pos.z * 4) >=
53+
num_batches) {
54+
break;
55+
}
56+
const uint src_pos_z = (4 * num_texel_per_batch * pos.z) +
57+
(k * num_texel_per_batch) + (index / 4);
58+
const uint src_pos_t = index % 4;
59+
out_texel[k] =
60+
VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0))[src_pos_t];
61+
}
62+
63+
imageStore(image_out, pos, out_texel);
64+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
select_channel_4d:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: select_channel_4d
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_type(DTYPE)}
13+
14+
layout(std430) buffer;
15+
16+
#include "indexing_utils.h"
17+
18+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
20+
21+
layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
22+
uvec4 data;
23+
}
24+
out_sizes;
25+
26+
// index to select
27+
layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal {
28+
int data;
29+
}
30+
index;
31+
32+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
33+
34+
void main() {
35+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
36+
37+
const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data);
38+
39+
if (any(greaterThanEqual(idx, out_sizes.data))) {
40+
return;
41+
}
42+
43+
// w
44+
const int src_x = pos.x;
45+
// h
46+
const int src_y = index.data;
47+
// c
48+
const int src_z = pos.y;
49+
50+
const VEC4_T v = VEC4_T(texelFetch(image_in, ivec3(src_x, src_y, src_z), 0));
51+
52+
for (int i = 0; i < 4; i++) {
53+
ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0);
54+
55+
// When the C-channel exceeds original block size, exit early
56+
if (new_pos.y >= out_sizes.data.y) {
57+
return;
58+
}
59+
60+
imageStore(image_out, new_pos, VEC4_T(v[i], 0, 0, 0));
61+
}
62+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
select_height_3d:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: select_height_3d
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_type(DTYPE)}
13+
14+
layout(std430) buffer;
15+
16+
#include "indexing_utils.h"
17+
18+
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
19+
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
20+
21+
layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
22+
uvec4 data;
23+
}
24+
out_sizes;
25+
26+
// index to select
27+
layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal {
28+
// data.x: index along height dim to select
29+
// data.y: number of batches
30+
// data.z: number of texels per batch
31+
// data.w: unused
32+
ivec4 data;
33+
}
34+
select_info;
35+
36+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
37+
38+
void main() {
39+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
40+
const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data);
41+
if (any(greaterThanEqual(idx, out_sizes.data))) {
42+
return;
43+
}
44+
45+
const int num_batches = select_info.data.y;
46+
const int num_texel_per_batch = select_info.data.z;
47+
const int index = select_info.data.x;
48+
49+
VEC4_T out_texel = VEC4_T(0, 0, 0, 0);
50+
// read in the same channel from 4 separate batches
51+
for (int k = 0; k < 4; k++) {
52+
if ((k + pos.z * 4) >= num_batches
53+
) { // < 4 batches for this texel, exit early
54+
break;
55+
}
56+
const uint src_pos_z = (pos.z * num_texel_per_batch * 4) +
57+
k * num_texel_per_batch + (pos.y / 4);
58+
out_texel[k] = VEC4_T(texelFetch(
59+
image_in, ivec3(pos.x, index, src_pos_z), 0))[pos.y % 4];
60+
}
61+
imageStore(image_out, pos, out_texel);
62+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
select_height_4d:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
NDIM: 3
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: select_height_4d

0 commit comments

Comments
 (0)