
Commit a9178f1

pytorchbot and morelos authored
[ET-VK][Ops] aten.var.dim from scratch implementation (#11380)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #11197 by @ahmtox
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/ahmtox/4/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/4/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/ahmtox/4/orig
@diff-train-skip-merge

Co-authored-by: morelos <[email protected]>
1 parent e4b91fa commit a9178f1

File tree

6 files changed: +628 −0 lines changed
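
For context before the diffs: aten.var.dim reduces a tensor along a single dimension and returns its variance, with Bessel's correction (scaling by n / (n - 1)) applied when unbiased is nonzero. Both shaders in this PR use the one-pass formulation Var(x) = E[x^2] - E[x]^2, accumulated as a running sum and sum of squares. A minimal NumPy sketch of that math, for reference only (names and shapes here are illustrative, not from the PR):

import numpy as np

def var_dim(x: np.ndarray, dim: int, unbiased: bool = True) -> np.ndarray:
    """One-pass variance along `dim`: Var = E[x^2] - E[x]^2, with optional
    Bessel correction (n / (n - 1)), mirroring the shader math."""
    n = x.shape[dim]
    s = x.sum(axis=dim)           # running sum, like shared_sum
    s_sq = (x * x).sum(axis=dim)  # running sum of squares, like shared_sum_sq
    mean = s / n
    var = s_sq / n - mean * mean
    if unbiased and n > 1:
        var *= n / (n - 1)
    return var

x = np.random.rand(2, 5, 3).astype(np.float32)
assert np.allclose(var_dim(x, 1), x.var(axis=1, ddof=1), atol=1e-5)
assert np.allclose(var_dim(x, 1, unbiased=False), x.var(axis=1, ddof=0), atol=1e-5)

The one-pass form trades some numerical stability for a single read of the input, which is what lets each shader thread keep only a sum, a sum of squares, and a count.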
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "out_buf", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "in_buf", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec4", "in_sizes")}
${layout_declare_ubo(B, "ivec4", "in_strides")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

layout(push_constant) uniform PushConstants {
  int unbiased;
} pc;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int reduce_dim = 0;

#define NWORKERS 4
#define MAX_THREADS 16

shared T shared_sum[NWORKERS];
shared T shared_sum_sq[NWORKERS];
shared int shared_count[NWORKERS];

#include "indexing_utils.h"

void main() {
  const ivec4 out_idx = ivec4(
      gl_GlobalInvocationID.x,
      gl_GlobalInvocationID.y,
      gl_GlobalInvocationID.z % out_sizes.z,
      gl_GlobalInvocationID.z / out_sizes.z);

  const uint tid = gl_LocalInvocationID[reduce_dim];

  shared_sum[tid] = T(0);
  shared_sum_sq[tid] = T(0);
  shared_count[tid] = 0;
  barrier();

  const int R = in_sizes[reduce_dim];
  const uint N = gl_WorkGroupSize[reduce_dim];

  // Each workgroup processes a contiguous chunk of the input tensor
  // along the reduce_dim. Each thread processes a part of this chunk.
  uint q = R / N;
  uint rem = R % N;

  uint len = q + (tid < rem ? 1u : 0u);
  uint base = tid * q + min(tid, rem);

  T sum = T(0);
  T sum_sq = T(0);
  int count = 0;

  ivec4 in_idx = out_idx;
  for (uint off = 0u; off < len; ++off) {
    uint i = base + off;
    in_idx[reduce_dim] = int(i);

    // out_idx is a 4D index, so for tensors with reduce_dim == 2,
    // the reduce_dim + 1 component must be corrected, since
    // gl_GlobalInvocationID.z is influenced by the tid
    if (reduce_dim == 2) {
      in_idx[reduce_dim + 1] -= int(tid);
    }

    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];

    sum += v;
    sum_sq += v * v;
    count += 1;
  }

  shared_sum[tid] = sum;
  shared_sum_sq[tid] = sum_sq;
  shared_count[tid] = count;
  barrier();

  if (tid == 0u) {
    T tot_sum = T(0);
    T tot_sum_sq = T(0);
    int tot_count = 0;

    for (uint i = 0; i < N; ++i) {
      tot_sum += shared_sum[i];
      tot_sum_sq += shared_sum_sq[i];
      tot_count += shared_count[i];
    }

    T var;
    if (tot_count > 0) {
      T mean = tot_sum / T(tot_count);
      var = (tot_sum_sq / T(tot_count)) - (mean * mean);
      if (pc.unbiased != 0 && tot_count > 1) {
        var *= T(tot_count) / T(tot_count - 1);
      }
    } else {
      // NaN to match PyTorch behavior
      var = T(0.0 / 0.0);
    }

    out_buf[tidx_to_bufi(out_idx, out_strides)] = var;
  }
}
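
The base/len arithmetic above splits the R-element reduction row into contiguous per-thread slices, with the first R % N threads taking one extra element. A small sketch (hypothetical helper, not part of the PR) checking that the slices tile [0, R) exactly:

def partition(R: int, N: int):
    """Mirror the shader's base/len computation: thread `tid` covers
    [base, base + len) where the first R % N threads get one extra element."""
    q, rem = divmod(R, N)
    slices = []
    for tid in range(N):
        length = q + (1 if tid < rem else 0)
        base = tid * q + min(tid, rem)
        slices.append(range(base, base + length))
    return slices

# The per-thread slices are disjoint and cover the whole row:
slices = partition(R=10, N=4)   # [0..2], [3..5], [6..7], [8..9]
assert sorted(i for s in slices for i in s) == list(range(10))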
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

var_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: var_buffer
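
The generate_variant_forall block drives ExecuTorch's Vulkan shader codegen: one variant of var_buffer is compiled per listed DTYPE, presumably named by suffixing the value onto the variant NAME (e.g. var_buffer_half, var_buffer_float; the exact naming is the codegen's convention and is not spelled out in this diff). The texture3d YAML later in the diff mirrors this file with STORAGE: texture3d.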
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(push_constant) uniform PushConstants {
  int unbiased;
} pc;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16

shared VEC4_T shared_sum[MAX_NTHREADS];
shared VEC4_T shared_sum_sq[MAX_NTHREADS];
shared int shared_count[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
  return tid.x + tid.y * NWORKERS;
}

VEC4_T calculate_variance(VEC4_T sum, VEC4_T sum_sq, int count) {
  VEC4_T mean = sum / float(count);
  VEC4_T variance = (sum_sq / float(count)) - (mean * mean);

  if ((pc.unbiased != 0) && (count > 1)) {
    variance = variance * (float(count) / float(count - 1));
  }

  return variance;
}

void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);

  VEC4_T sum = VEC4_T(0);
  VEC4_T sum_sq = VEC4_T(0);
  int count = 0;

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    VEC4_T val = load_texel(tin, scan_pos);
    sum += val;
    sum_sq += val * val;
    count += 1;
  }
  // Write partial output to shared memory and synchronize work group
  shared_sum[smi] = sum;
  shared_sum_sq[smi] = sum_sq;
  shared_count[smi] = count;
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs to aggregate the partial outputs
  if (tid.x == 0) {
    int group_i = tid.y * NWORKERS;
    sum = shared_sum[group_i];
    sum_sq = shared_sum_sq[group_i];
    count = shared_count[group_i];

    for (int i = 1; i < NWORKERS; i++) {
      int idx = tid.y * NWORKERS + i;
      sum += shared_sum[idx];
      sum_sq += shared_sum_sq[idx];
      count += shared_count[idx];
    }

    // Determine if there are any padding elements in the final texel of the
    // packed dimension
    const int nspill = mod4(tin_sizes[packed_dim]);
    // Detect if this thread is working on the final texels of the packed
    // dimension, which may have padding elements
    const bool is_last_texel =
        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

    VEC4_T variance = calculate_variance(sum, sum_sq, count);

    // Explicitly set padding elements to 0
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int i = nspill; i < 4; i++) {
        variance[i] = 0;
      }
    }

    scan_pos[reduce_dim] = tid.x;
    write_texel(tout, scan_pos, variance);
  }
}

/*
 * Compute reduction where the reduction dim is also the packed dim. This case
 * is complex because the reduction needs to occur over the individual texels.
 * Therefore, in this algorithm each element of the accumulator texels is
 * itself a partial output. Special care has to be taken to ignore padding
 * elements in texels (which occur when the size of the packed dim is not a
 * multiple of 4) so that they do not influence the output of the reduction.
 */
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);

  // Number of non-padding elements in the last texel in the reduction row
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Only reduce up to the last "complete" texel. The last texel will need to be
  // handled specially if it has padding elements.
  const int reduce_len = tin_sizes[packed_dim] - nspill;

  VEC4_T sum = VEC4_T(0);
  VEC4_T sum_sq = VEC4_T(0);
  int count = 0;

  // Partially accumulate over texels tid.x, tid.x + NWORKERS,
  // tid.x + 2 * NWORKERS, ... of the reduction row
  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    VEC4_T val = load_texel(tin, scan_pos);
    sum += val;
    sum_sq += val * val;
    count += 4;
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually so that the
  // padding elements are ignored
  if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
    const VEC4_T val = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; i++) {
      sum.x += val[i];
      sum_sq.x += val[i] * val[i];
      count += 1;
    }
  }
  // Write partial output to shared memory and synchronize work group
  shared_sum[smi] = sum;
  shared_sum_sq[smi] = sum_sq;
  shared_count[smi] = count;
  barrier();

  // Since the reduction row is reduced to only one element, only the "main"
  // thread in the group needs to aggregate the partial outputs
  if (tid.x == 0) {
    sum = shared_sum[tid.y * NWORKERS];
    sum_sq = shared_sum_sq[tid.y * NWORKERS];
    count = shared_count[tid.y * NWORKERS];
    for (int i = 1; i < NWORKERS; i++) {
      int idx = tid.y * NWORKERS + i;
      sum += shared_sum[idx];
      sum_sq += shared_sum_sq[idx];
      count += shared_count[idx];
    }

    // Combine across the elements of the accumulated texels
    float total_sum = sum.x + sum.y + sum.z + sum.w;
    float total_sum_sq = sum_sq.x + sum_sq.y + sum_sq.z + sum_sq.w;
    int total_count = count;

    float mean = total_sum / float(total_count);
    float variance = (total_sum_sq / float(total_count)) - (mean * mean);

    if ((pc.unbiased != 0) && (total_count > 1)) {
      variance = variance * (float(total_count) / float(total_count - 1));
    }

    scan_pos[reduce_dim] = tid.x;
    write_texel(tout, scan_pos, VEC4_T(variance, 0, 0, 0));
  }
}

void main() {
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim] = 0;

  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim],
      gl_LocalInvocationID[group_dim]);

  if (any(greaterThanEqual(scan_pos, tin_limits))) {
    return;
  }

  if (reduce_dim != packed_dim) {
    reduce_nonpacked_dim(tid, scan_pos);
  } else {
    reduce_packed_dim(tid, scan_pos);
  }
}
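
When the reduction dim is also the packed dim, values arrive four lanes at a time and the final texel of a row may contain up to three padding lanes; including them would inflate count even though zero padding adds nothing to the sums. A rough Python emulation of reduce_packed_dim's accumulation on a single row (texel layout and names assumed for illustration, not from the PR):

import numpy as np

def packed_dim_variance(row: np.ndarray, unbiased: bool = True) -> float:
    """Emulate reduce_packed_dim on one reduction row: accumulate whole
    texels of 4, then handle the trailing partial texel element-wise so
    padding lanes never enter sum/sum_sq/count."""
    n = len(row)
    nspill = n % 4                # non-padding elements in the last texel
    reduce_len = n - nspill       # last multiple-of-4 boundary
    padded = np.zeros((n + 3) // 4 * 4, dtype=np.float64)
    padded[:n] = row              # zero padding, as in the texture storage

    sum_, sum_sq, count = 0.0, 0.0, 0
    for i in range(0, reduce_len, 4):          # whole texels
        texel = padded[i:i + 4]
        sum_ += texel.sum()
        sum_sq += (texel * texel).sum()
        count += 4
    if nspill > 0:                             # trailing partial texel
        texel = padded[reduce_len:reduce_len + 4]
        for j in range(nspill):                # padding lanes ignored
            sum_ += texel[j]
            sum_sq += texel[j] * texel[j]
            count += 1

    mean = sum_ / count
    var = sum_sq / count - mean * mean
    if unbiased and count > 1:
        var *= count / (count - 1)
    return var

row = np.random.rand(10)  # 10 % 4 == 2, so the last texel has 2 padding lanes
assert np.isclose(packed_dim_variance(row), row.var(ddof=1))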
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

var_texture3d:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: var_texture3d
