Skip to content

Commit 8140a90

Browse files
pytorchbot and SS-JIA authored
[ET-VK] Implement generic reduction shader + mean, sum, amax, amin (#6473)
Pull Request resolved: #6457 ## Context Introduce a generic shader to compute reduction along a single dim with `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. ghstack-source-id: 249709743 exported-using-ghexport Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/) --------- Co-authored-by: Stephen Jia <[email protected]>
1 parent 5692203 commit 8140a90

File tree

13 files changed

+418
-424
lines changed

13 files changed

+418
-424
lines changed

backends/vulkan/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ runtime.python_library(
2727
"//executorch/backends/transforms:fuse_conv_with_clamp",
2828
"//executorch/backends/transforms:fuse_dequant_linear",
2929
"//executorch/backends/transforms:fuse_view_copy",
30-
"//executorch/backends/transforms:mean_to_sum_div",
3130
"//executorch/backends/transforms:remove_clone_ops",
3231
"//executorch/backends/vulkan/_passes:vulkan_passes",
3332
"//executorch/exir:graph_module",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def __contains__(self, op):
8989
# Reduction
9090
exir_ops.edge.aten._log_softmax.default,
9191
exir_ops.edge.aten._softmax.default,
92+
exir_ops.edge.aten.mean.dim,
93+
exir_ops.edge.aten.sum.dim_IntList,
94+
exir_ops.edge.aten.amax.default,
95+
exir_ops.edge.aten.amin.default,
9296
# 2D Pooling
9397
exir_ops.edge.aten.avg_pool2d.default,
9498
exir_ops.edge.aten.max_pool2d_with_indices.default,
@@ -101,9 +105,6 @@ def __contains__(self, op):
101105
]
102106

103107
NO_DYNAMIC_SHAPE = [
104-
# Reduction
105-
exir_ops.edge.aten.mean.dim,
106-
exir_ops.edge.aten.sum.dim_IntList,
107108
# Normalization
108109
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
109110
exir_ops.edge.aten.native_layer_norm.default,
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
14+
${define_active_storage_type(STORAGE)}
15+
16+
#extension GL_EXT_control_flow_attributes : require
17+
18+
layout(std430) buffer;
19+
20+
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
22+
23+
${layout_declare_ubo(B, "ivec3", "tin_limits")}
24+
${layout_declare_ubo(B, "ivec4", "tin_sizes")}
25+
26+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
27+
28+
layout(constant_id = 3) const int packed_dim = 0;
29+
layout(constant_id = 4) const int reduce_dim = 0;
30+
layout(constant_id = 5) const int group_dim = 1;
31+
32+
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
33+
// threads that will co-operate to compute one reduction output. There may be
34+
// multiple groups computing distinct reduction outputs within one work group.
35+
#define NWORKERS 4
36+
37+
// Sets an upper limit on the total size of a work group based on how many
38+
// elements are allocated in the shared memory array below. Each thread in the
39+
// work group will write into its assigned element in the shared array.
40+
#define MAX_NTHREADS 16
41+
42+
43+
shared vec4 shared_vecs[MAX_NTHREADS];
44+
45+
#include "indexing_utils.h"
46+
47+
int tid_to_smi(const ivec2 tid) {
48+
return tid.x + tid.y * NWORKERS;
49+
}
50+
51+
/*
52+
* The functions below compute reduction along a single dimension for a tensor.
53+
 * The shader template generalizes reduction by abstracting the initial value of
54+
* the accumulator, the calculation used to update the accumulator with new
55+
* values, and a postprocessing calculation that can be used to modify the
56+
* accumulator before writing to output.
57+
*
58+
 * This shader also utilizes shared memory to have multiple threads help compute
59+
* the max and sum reduction operations. A total of NGROUPS x NWORKERS threads
60+
* are expected to be launched. Each group works on a unique reduction "row", and
61+
* within a group NWORKERS threads co-operate to compute the max and sum of one
62+
* "row". Each worker in the group is responsible for computing a partial output
63+
* of the "row" and uploading it to shared memory; the overall reduction output
64+
* can then be determined by aggregating the partial outputs stored in shared
65+
* memory.
66+
*
67+
* As a caveat, this shader does not currently support cases where `batch` > 1
68+
* and the reduce dim happens to also be the batch concatenation dim. To support
69+
* this, there will need to be additional logic to set the starting value of
70+
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
71+
* supporting this case is left as an exercise for when it is required.
72+
*/
73+
74+
// Initializing the accumulator accepts the first value in the reduction row,
75+
// since some reduction operations (i.e. amax, amin) prefer to initialize with
76+
// a data point instead of a static value.
77+
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
78+
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
79+
// Useful for operators such as mean which want to perform a final calculation
80+
// with the accumulator.
81+
#define POSTPROCESS(accum) ${POSTPROCESS}
82+
83+
/*
84+
* Computes reduction where the reduction dim is orthogonal to the packed dim.
85+
* This case is simpler because each element of a texel belongs to a separate
86+
* reduction "group", meaning we don't have to perform reduction along a texel.
87+
*/
88+
void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
89+
// shared memory index of this thread
90+
const int smi = tid_to_smi(tid);
91+
92+
scan_pos[reduce_dim] = 0;
93+
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
94+
95+
scan_pos[reduce_dim] = tid.x;
96+
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
97+
// the reduction row
98+
for (int i = tid.x; i < tin_sizes[reduce_dim];
99+
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
100+
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
101+
}
102+
// Write partial output to shared memory and synchronize work group
103+
shared_vecs[smi] = accum;
104+
barrier();
105+
106+
// Since the reduction row is reduced to only one element, only the "main"
107+
// thread in the group needs to aggregate the partial outputs
108+
if (tid.x == 0) {
109+
// Iterate over the partial outputs to obtain the overall output
110+
int group_i = tid.y * NWORKERS;
111+
accum = shared_vecs[group_i++];
112+
for (int i = 1; i < NWORKERS; i++, group_i++) {
113+
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
114+
}
115+
116+
// Determine if there are any padding elements in the final texel of the
117+
// packed dimension
118+
const int nspill = mod4(tin_sizes[packed_dim]);
119+
// Detect if this thread is working on the final texels of the packed
120+
// dimension, which may have padding elements
121+
const bool is_last_texel =
122+
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
123+
124+
// Explicitly set padding elements to 0
125+
if (is_last_texel && nspill > 0) {
126+
[[unroll]] for (int i = nspill; i < 4; i++) {
127+
accum[i] = 0;
128+
}
129+
}
130+
scan_pos[reduce_dim] = tid.x;
131+
write_texel(tout, scan_pos, POSTPROCESS(accum));
132+
}
133+
}
134+
135+
/*
136+
 * Computes reduction where the reduction dim is also the packed dim. This case is
137+
* complex because the reduction needs to occur over the individual texels.
138+
* Therefore, in this algorithm each element of the accumulator texels are
139+
* themselves partial outputs. Special care has to be taken to ignore padding
140+
* elements in texels (which occur when the size of the packed dim is not a
141+
* multiple of 4) so that they do not influence the output of reduction.
142+
*/
143+
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
144+
// shared memory index of this thread
145+
const int smi = tid_to_smi(tid);
146+
147+
// Number of non-padding elements in the last texel in the reduction row
148+
const int nspill = mod4(tin_sizes[packed_dim]);
149+
// Only reduce up to the last "complete" texel. The last texel will need to be
150+
// handled specially if it has padding elements.
151+
const int reduce_len = tin_sizes[packed_dim] - nspill;
152+
153+
scan_pos[reduce_dim] = 0;
154+
vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));
155+
156+
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
157+
// the reduction row
158+
scan_pos[reduce_dim] = tid.x;
159+
for (int i = tid.x * 4; i < reduce_len;
160+
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
161+
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
162+
}
163+
// For the last texel in the dim, if there are padding elements then each
164+
// element of the texel needs to be processed individually such that the
165+
// padding elements are ignored
166+
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
167+
const vec4 intex = load_texel(tin, scan_pos);
168+
for (int i = 0; i < nspill; i++) {
169+
accum.x = UPDATE_ACCUM(accum.x, intex[i]);
170+
}
171+
}
172+
// Write partial output to shared memory and synchronize work group
173+
shared_vecs[smi] = accum;
174+
barrier();
175+
176+
// Since the reduction row is reduced to only one element, only the "main"
177+
// thread in the group needs to aggregate the partial outputs
178+
if (tid.x == 0) {
179+
// Iterate over the partial outputs to obtain the overall output
180+
int group_i = tid.y * NWORKERS;
181+
accum = shared_vecs[group_i++];
182+
for (int i = 1; i < NWORKERS; i++, group_i++) {
183+
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
184+
}
185+
// Each element of the texel is itself a partial output; iterate over the
186+
// texel to find the overall output
187+
float accum_final = accum.x;
188+
[[unroll]] for (int i = 1; i < 4; i++) {
189+
accum_final = UPDATE_ACCUM(accum[i], accum_final);
190+
}
191+
192+
scan_pos[reduce_dim] = tid.x;
193+
write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
194+
}
195+
}
196+
197+
void main() {
198+
ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
199+
scan_pos[reduce_dim] = 0;
200+
201+
const ivec2 tid = ivec2(
202+
gl_LocalInvocationID[reduce_dim],
203+
gl_LocalInvocationID[group_dim]);
204+
205+
if (any(greaterThanEqual(scan_pos, tin_limits))) {
206+
return;
207+
}
208+
209+
if (reduce_dim != packed_dim) {
210+
reduce_nonpacked_dim(tid, scan_pos);
211+
} else {
212+
reduce_packed_dim(tid, scan_pos);
213+
}
214+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
reduce:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
INIT_ACCUM: VEC4_T(0)
12+
UPDATE_ACCUM: accum + new_val
13+
POSTPROCESS: accum
14+
generate_variant_forall:
15+
DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
shader_variants:
19+
- NAME: sum
20+
- NAME: mean
21+
POSTPROCESS: (accum / tin_sizes[reduce_dim])
22+
- NAME: amax
23+
INIT_ACCUM: first_val
24+
UPDATE_ACCUM: max(accum, new_val)
25+
POSTPROCESS: accum
26+
- NAME: amin
27+
INIT_ACCUM: first_val
28+
UPDATE_ACCUM: min(accum, new_val)
29+
POSTPROCESS: accum

backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl

Lines changed: 0 additions & 108 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)