Skip to content

Commit b7c4590

Browse files
committed
Update on "[ET-VK] Allow overwriting local workgroup size"
Introduce a `GraphConfig` toggle following the convention of `storage_type` and `memory_layout`. Differential Revision: [D58957058](https://our.internmc.facebook.com/intern/diff/D58957058/) [ghstack-poisoned]
2 parents 3ddce46 + 1875f3a commit b7c4590

File tree

4 files changed

+24
-16
lines changed

4 files changed

+24
-16
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15+
#define TILE_SIZE ${TILE_SIZE}
16+
1517
#define op(X, A, B) ${OPERATOR}
1618

1719
#include "indexing_utils.h"
@@ -73,8 +75,8 @@ void main() {
7375

7476
VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
7577
int kx = 0;
76-
for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += dilation.y, i++) {
77-
for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += dilation.x, j++) {
78+
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
79+
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
7880
// The weight kernel was rearranged such that every NxN filter is
7981
// flattened to fit in one row. Each filter was then stacked on top of
8082
// each other vertically.

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15+
#define TILE_SIZE ${TILE_SIZE}
16+
1517
#define op(X, A, B) ${OPERATOR}
1618

1719
#include "indexing_utils.h"
@@ -65,11 +67,11 @@ void main() {
6567
// +--------+--------+
6668
// | pos[2] | pos[3] |
6769
// +--------+--------+
68-
ivec3 pos[${TILE_SIZE * TILE_SIZE}];
69-
for (int y = 0, i = 0; y < ${TILE_SIZE}; ++y) {
70-
for (int x = 0; x < ${TILE_SIZE}; ++x) {
70+
ivec3 pos[TILE_SIZE * TILE_SIZE];
71+
for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
72+
for (int x = 0; x < TILE_SIZE; ++x) {
7173
pos[i] = ivec3(
72-
gpos.x * ${TILE_SIZE} + x, gpos.y * ${TILE_SIZE} + y, gpos.z);
74+
gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
7375
i++;
7476
}
7577
}
@@ -83,14 +85,14 @@ void main() {
8385
// Compute the index of the input texture that needs to be loaded for each
8486
// output position. Note that negative indices can be produced indicating that
8587
// the top-left element is in a region added by padding.
86-
ivec2 ipos[${TILE_SIZE * TILE_SIZE}];
87-
for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
88+
ivec2 ipos[TILE_SIZE * TILE_SIZE];
89+
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
8890
ipos[i] = pos[i].xy * stride - padding;
8991
}
9092

91-
vec4 sum[${TILE_SIZE * TILE_SIZE}];
93+
vec4 sum[TILE_SIZE * TILE_SIZE];
9294
sum[0] = texelFetch(bias_in, ivec2(gpos.z, 0), 0);
93-
for (int i = 1; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
95+
for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
9496
sum[i] = sum[0];
9597
}
9698

@@ -99,17 +101,17 @@ void main() {
99101
// During prepacking, the weight tensor has been permuted so that the
100102
// channel (IC) dim is along the x-axis, and the batch (OC) dim is along
101103
// the z-axis.
102-
vec4 in_tex[${TILE_SIZE * TILE_SIZE}];
104+
vec4 in_tex[TILE_SIZE * TILE_SIZE];
103105
const vec4 ktex_0 = texelFetch(kernel_in, ivec2(z + 0, gpos.z), 0);
104106
const vec4 ktex_1 = texelFetch(kernel_in, ivec2(z + 1, gpos.z), 0);
105107
const vec4 ktex_2 = texelFetch(kernel_in, ivec2(z + 2, gpos.z), 0);
106108
const vec4 ktex_3 = texelFetch(kernel_in, ivec2(z + 3, gpos.z), 0);
107109

108-
for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
110+
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
109111
in_tex[i] = texelFetch(image_in, ivec3(ipos[i], z4), 0);
110112
}
111113

112-
for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
114+
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
113115
// For 2x2 tile size algorithm works as follows.
114116
// To explain the calculations below, the contents of one in_tex and the
115117
// group of 4 texels loaded from kernel_in are shown:
@@ -150,7 +152,7 @@ void main() {
150152
}
151153
}
152154

153-
for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) {
155+
for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
154156
if (all(lessThan(pos[i], out_limits))) {
155157
imageStore(image_out, pos[i], op(sum[i], out_min, out_max));
156158
}

backends/vulkan/runtime/graph/ops/glsl/full.glsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15+
#define POS ${get_pos[NDIM]("pos")}
16+
1517
#include "indexing_utils.h"
1618

1719
layout(std430) buffer;
@@ -48,5 +50,5 @@ void main() {
4850
outtex = outtex * valid_idx;
4951
}
5052

51-
imageStore(image_out, ${get_pos[NDIM]("pos")}, outtex);
53+
imageStore(image_out, POS, outtex);
5254
}

backends/vulkan/test/glsl/idx_fill_texture.glsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#define VEC4_T ${texel_type(DTYPE)}
1414

15+
#define POS ${get_pos[NDIM]("pos")}
16+
1517
#include "indexing_utils.h"
1618

1719
layout(std430) buffer;
@@ -36,5 +38,5 @@ void main() {
3638

3739
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim);
3840
VEC4_T texel = VEC4_T(buf_indices);
39-
imageStore(image_out, ${get_pos[NDIM]("pos")}, texel);
41+
imageStore(image_out, POS, texel);
4042
}

0 commit comments

Comments
 (0)