Skip to content

Commit a91eb8a

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Generalize softmax for packed dim vs non packed dim (#5755)
Summary: Pull Request resolved: #5755 ## Context This diff performs a rewrite of the `softmax` shaders. Previously, the shaders were separated into the `channels` case and `batch_height_width` case. This is because channels packing was the only packing format used when these shaders were written; thus, these shaders represented the case when the reduction dim is equal to the packed dim, and the case when the reduction dim is orthogonal to the packed dim respectively. Now that we expect tensors to be width packed as well, the `channels`/`batch_height_width` separation no longer makes sense. This diff consolidates both cases into a single shader that takes the `packed_dim` and `reduce_dim` as specialization constants, and selects the correct function to execute based on whether they are the same or different. Additionally, I implemented an optimization in the form of using a co-operative algorithm to allow multiple threads to co-operate in computing max and sum. More details can be found in the comments of the new shader file. ghstack-source-id: 245571371 Reviewed By: jorgep31415 Differential Revision: D63642091 fbshipit-source-id: cfe960a20cacdb7670390f7626fbf64366a734b0
1 parent b60fa71 commit a91eb8a

File tree

14 files changed

+401
-291
lines changed

14 files changed

+401
-291
lines changed

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,14 @@ class vTensor final {
395395
return packed_dim_;
396396
}
397397

398+
/*
399+
* Returns the WHCN index of the dimension that is used to concatenate batches
400+
* as an int32_t.
401+
*/
402+
// Reads entry 3 of axis_map_ and narrows it to int32_t through
// utils::safe_downcast before returning it.
inline int32_t concat_dim() const {
403+
return utils::safe_downcast<int32_t>(axis_map_.at(3));
404+
}
405+
398406
// Returns a const reference to the sizes_ member; no copy is made.
inline const std::vector<int64_t>& sizes() const {
399407
return sizes_;
400408
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,10 @@ class ComputeGraph final {
322322
return values_.at(idx).toConstTensor().packed_dim();
323323
}
324324

325+
// Looks up the value stored at `idx` in values_, views it as a const tensor,
// and forwards the result of that tensor's concat_dim().
inline int32_t concat_dim_of(const ValueRef idx) const {
326+
return values_.at(idx).toConstTensor().concat_dim();
327+
}
328+
325329
// Looks up the value stored at `idx` in values_, views it as a (non-const)
// tensor, and forwards the vkapi::BufferBindInfo returned by that tensor's
// sizes_ubo().
inline vkapi::BufferBindInfo sizes_ubo(const ValueRef idx) {
326330
return values_.at(idx).toTensor().sizes_ubo();
327331
}

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353
*/
5454
#define alignup4(x) (((x) + 3) & -4) // argument parenthesized so low-precedence expressions (e.g. a | b) expand correctly
5555

56+
/*
57+
* Fast modulo by 4 using bit masking
58+
*/
59+
#define mod4(x) ((x) & 3) // argument parenthesized so low-precedence expressions (e.g. a | b) expand correctly
60+
5661
/*
5762
* Find the packed dimension of a tensor given its strides. The packed dimension
5863
* is the "fastest moving" dimension which will have a stride of 1.
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define op1(X) ${OPERATOR1}
14+
15+
#define op2(X, Y) ${OPERATOR2}
16+
17+
${define_active_storage_type(STORAGE)}
18+
19+
#extension GL_EXT_control_flow_attributes : require
20+
21+
layout(std430) buffer;
22+
23+
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
24+
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
25+
26+
${layout_declare_ubo(B, "ivec3", "tout_limits")}
27+
${layout_declare_ubo(B, "ivec4", "tin_sizes")}
28+
29+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
30+
31+
layout(constant_id = 3) const int packed_dim = 0;
32+
layout(constant_id = 4) const int reduce_dim = 0;
33+
layout(constant_id = 5) const int group_dim = 1;
34+
35+
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
36+
// threads that will co-operate to compute one reduction output. There may be
37+
// multiple groups computing distinct reduction outputs within one work group.
38+
#define NWORKERS 4
39+
40+
// Sets an upper limit on the total size of a work group based on how many
41+
// elements are allocated in the shared memory array below. Each thread in the
42+
// work group will write into its assigned element in the shared array.
43+
#define MAX_NTHREADS 16
44+
45+
shared vec4 shared_vecs[MAX_NTHREADS];
46+
47+
#include "indexing_utils.h"
48+
49+
// Maps a (worker, group) thread id to its slot in shared_vecs. Each group
// owns NWORKERS consecutive slots; tid.y selects the group and tid.x the slot
// within it.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}
52+
53+
/*
54+
* The shaders below compute softmax for a tensor. Softmax is an interesting mix
55+
* between a reduction operator and a unary elementwise operator, defined as
56+
* exp(x) / (sum of exp(x)). The general flow of the computation is:
57+
*
58+
* First, find the maximum element along the reduction dim. The maximum element
59+
* is used to preserve numerical stability, since division of exponents is
60+
* translation invariant.
61+
*
62+
* Next, compute the sum of exp(x - max_element) along the reduction dim.
63+
*
64+
* Finally, for each element along the reduction dim, we compute the output as
65+
* exp(x - max_element) / sum_of_exponents.
66+
*
67+
* The shaders below also utilize shared memory to have multiple threads help
68+
* compute the max and sum reduction operations. A total of NGROUPS x NWORKERS
69+
* threads are launched. Each group works on a unique reduction "row", and
70+
* within a group NWORKERS threads co-operate to compute the max and sum of one
71+
* "row". Each worker in the group is responsible for computing a partial output
72+
* of the "row" and uploading it to shared memory; the overall reduction output
73+
* can then be determined by aggregating the partial outputs stored in shared
74+
* memory.
75+
*
76+
* As a caveat, this shader does not currently support cases where `batch` > 1
77+
* and the reduce dim happens to also be the batch concatenation dim. To support
78+
* this, there will need to be additional logic to set the starting value of
79+
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
80+
* supporting this case is left as an exercise for when it is required.
81+
*
82+
* As a final note, log softmax is supported with this shader as well via
83+
* the op1 and op2 macro definitions. See the corresponding YAML file for more
84+
* details.
85+
*/
86+
87+
/*
88+
* Computes softmax where the reduction dim is orthogonal to the packed dim.
89+
* This case is simpler because each element of a texel belongs to a separate
90+
* reduction dim, meaning we don't have to perform reduction along a texel.
91+
*/
92+
// tid:      x = index of this worker within its group (0..NWORKERS-1),
//           y = index of the group within the work group.
// scan_pos: texel position of the reduction "row" handled by this group; the
//           reduce_dim component is overwritten while scanning.
void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // used to iterate over all shared memory in the group
  int group_i;

  // Seed the running maximum with the first texel of the row. Position 0 along
  // the reduce dim is always valid, whereas position tid.x may lie past the end
  // of the row when the row is shorter than NWORKERS; seeding from it would
  // fold an out-of-bounds read into the maximum.
  scan_pos[reduce_dim] = 0;
  vec4 max_elements = load_texel(tin, scan_pos);

  // This thread computes a partial maximum
  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    max_elements = max(max_elements, load_texel(tin, scan_pos));
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
  max_elements = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    max_elements = max(max_elements, shared_vecs[group_i]);
  }
  // Every worker must finish reading the partial maximums before any worker
  // reuses its shared memory slot for a partial sum below.
  barrier();

  scan_pos[reduce_dim] = tid.x;
  vec4 denominators = vec4(0);
  // Compute partial sum
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    denominators += exp(load_texel(tin, scan_pos) - max_elements);
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
  denominators = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    denominators += shared_vecs[group_i];
  }

  // Determine if there are any padding elements in the final texel of the
  // packed dimension
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Detect if this thread is working on the final texels of the packed
  // dimension, which may have padding elements
  const bool is_last_texel =
      scan_pos[packed_dim] == (tout_limits[packed_dim] - 1);

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
    vec4 outtex = op2(numerators, denominators);
    // For the last texel in the packed dim, make sure that the padding elements
    // are explicitly set to 0. Otherwise, they may influence computations later
    // down the line.
    if (is_last_texel && nspill > 0) {
      // Distinct index name: the enclosing loop already uses `i`.
      [[unroll]] for (int lane = nspill; lane < 4; ++lane) {
        outtex[lane] = 0;
      }
    }
    write_texel(tout, scan_pos, outtex);
  }
}
154+
155+
/*
156+
* Compute softmax where the reduction dim is also the packed dim. This case is
157+
* complex because the reduction needs to occur over the individual texels.
158+
* Therefore, in this algorithm each element of the accumulator texels is
159+
* itself a partial output. Special care has to be taken to ignore padding
160+
* elements in texels (which occur when the size of the packed dim is not a
161+
* multiple of 4) so that they do not influence the output of reduction.
162+
*/
163+
// tid:      x = index of this worker within its group (0..NWORKERS-1),
//           y = index of the group within the work group.
// scan_pos: texel position of the reduction "row" handled by this group; the
//           reduce_dim component is overwritten while scanning.
void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // used to iterate over all shared memory in the group
  int group_i;

  // Number of valid elements in the final texel (0 means the final texel is
  // full), and the number of elements covered by completely full texels.
  const int nspill = mod4(tin_sizes[packed_dim]);
  const int reduce_len = tin_sizes[packed_dim] - nspill;

  // Seed the running maximum with the very first element of the row, which is
  // always a valid element. Seeding from texel tid.x instead would read out of
  // bounds when the row has fewer texels than NWORKERS.
  scan_pos[reduce_dim] = 0;
  vec4 max_elements = vec4(load_texel(tin, scan_pos).x);

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    max_elements = max(max_elements, load_texel(tin, scan_pos));
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1 && nspill > 0) {
    const vec4 intex = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; ++i) {
      max_elements.x = max(intex[i], max_elements.x);
    }
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
  max_elements = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    max_elements = max(max_elements, shared_vecs[group_i]);
  }
  // Each element of the texel is itself a partial maximum; iterate over the
  // texel to find the actual maximum
  float max_element = max_elements.x;
  [[unroll]] for (int i = 1; i < 4; ++i) {
    max_element = max(max_elements[i], max_element);
  }
  // Every worker must finish reading the partial maximums before any worker
  // reuses its shared memory slot for a partial sum below.
  barrier();

  scan_pos[reduce_dim] = tid.x;
  vec4 denominators = vec4(0);
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    denominators += exp(load_texel(tin, scan_pos) - max_element);
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
    const vec4 intex = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; ++i) {
      denominators.x += exp(intex[i] - max_element);
    }
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
  denominators = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    denominators += shared_vecs[group_i];
  }
  // Reduce over the accumulated texel to find the overall sum
  float denominator = 0;
  [[unroll]] for (int i = 0; i < 4; ++i) {
    denominator += denominators[i];
  }

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
    write_texel(tout, scan_pos, op2(numerators, denominator));
  }
  // For the last texel in the dim, if there are padding elements then the
  // padding elements need to be set to 0 explicitly, otherwise they may
  // influence subsequent operations.
  if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
    const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
    vec4 outtex = op2(numerator, denominator);
    [[unroll]] for (int i = nspill; i < 4; ++i) {
      outtex[i] = 0;
    }
    write_texel(tout, scan_pos, outtex);
  }
}
249+
250+
// Entry point. Derives the reduction row and the (worker, group) id of this
// invocation, then dispatches to the packed or non-packed implementation
// depending on whether the reduce dim coincides with the packed dim.
void main() {
  // (worker index along the reduce dim, group index along the group dim)
  const ivec2 worker_id = ivec2(
      gl_LocalInvocationID[reduce_dim],
      gl_LocalInvocationID[group_dim]);

  // Start of the row: the global position with the reduce dim reset to 0.
  ivec3 row_origin = ivec3(gl_GlobalInvocationID);
  row_origin[reduce_dim] = 0;

  // NOTE(review): invocations that leave here never reach the barrier() calls
  // in the functions below, while other invocations of the same work group
  // may - confirm this is acceptable on all target drivers.
  if (any(greaterThanEqual(row_origin, tout_limits))) {
    return;
  }

  if (reduce_dim == packed_dim) {
    softmax_packed_dim(worker_id, row_origin);
  } else {
    softmax_nonpacked_dim(worker_id, row_origin);
  }
}

backends/vulkan/runtime/graph/ops/glsl/softmax.h

Lines changed: 0 additions & 43 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/softmax_channel.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/softmax.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
softmax_channel:
7+
softmax:
88
parameter_names_with_default_values:
99
OPERATOR1: exp(X)
1010
OPERATOR2: X / Y
11-
NDIM: 3
1211
DTYPE: float
12+
STORAGE: texture3d
1313
generate_variant_forall:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
1717
shader_variants:
18-
- NAME: softmax_channel
19-
- NAME: log_softmax_channel
18+
- NAME: softmax
19+
- NAME: log_softmax
2020
OPERATOR1: X
2121
OPERATOR2: X - log(Y)

0 commit comments

Comments
 (0)