Skip to content

Commit 350313b

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] aten.var.dim from scratch implementation"
Created the var.dim operator (which functionally supports var) from scratch Differential Revision: [D75244137](https://our.internmc.facebook.com/intern/diff/D75244137/) [ghstack-poisoned]
1 parent f6e2952 commit 350313b

File tree

3 files changed

+23
-19
lines changed

3 files changed

+23
-19
lines changed

backends/vulkan/runtime/graph/ops/glsl/var.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/var_texture3d.glsl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ layout(constant_id = 5) const int group_dim = 1;
4343
// work group will write into its assigned element in the shared array.
4444
#define MAX_NTHREADS 16
4545

46-
shared vec4 shared_sum[MAX_NTHREADS];
47-
shared vec4 shared_sum_sq[MAX_NTHREADS];
46+
shared VEC4_T shared_sum[MAX_NTHREADS];
47+
shared VEC4_T shared_sum_sq[MAX_NTHREADS];
4848
shared int shared_count[MAX_NTHREADS];
4949

5050
#include "indexing_utils.h"
@@ -53,9 +53,9 @@ int tid_to_smi(const ivec2 tid) {
5353
return tid.x + tid.y * NWORKERS;
5454
}
5555

56-
vec4 calculate_variance(vec4 sum, vec4 sum_sq, int count) {
57-
vec4 mean = sum / float(count);
58-
vec4 variance = (sum_sq / float(count)) - (mean * mean);
56+
VEC4_T calculate_variance(VEC4_T sum, VEC4_T sum_sq, int count) {
57+
VEC4_T mean = sum / float(count);
58+
VEC4_T variance = (sum_sq / float(count)) - (mean * mean);
5959

6060
if ((pc.unbiased != 0) && (count > 1)) {
6161
variance = variance * (float(count) / float(count - 1.0));
@@ -68,14 +68,14 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
6868
// shared memory index of this thread
6969
const int smi = tid_to_smi(tid);
7070

71-
vec4 sum = VEC4_T(0);
72-
vec4 sum_sq = VEC4_T(0);
71+
VEC4_T sum = VEC4_T(0);
72+
VEC4_T sum_sq = VEC4_T(0);
7373
int count = 0;
7474

7575
scan_pos[reduce_dim] = tid.x;
7676
for (int i = tid.x; i < tin_sizes[reduce_dim];
7777
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
78-
vec4 val = load_texel(tin, scan_pos);
78+
VEC4_T val = load_texel(tin, scan_pos);
7979
sum += val;
8080
sum_sq += val * val;
8181
count += 1;
@@ -109,7 +109,7 @@ void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
109109
const bool is_last_texel =
110110
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
111111

112-
vec4 variance = calculate_variance(sum, sum_sq, count);
112+
VEC4_T variance = calculate_variance(sum, sum_sq, count);
113113

114114
// Explicitly set padding elements to 0
115115
if (is_last_texel && nspill > 0) {
@@ -141,16 +141,16 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
141141
// handled specially if it has padding elements.
142142
const int reduce_len = tin_sizes[packed_dim] - nspill;
143143

144-
vec4 sum = VEC4_T(0);
145-
vec4 sum_sq = VEC4_T(0);
144+
VEC4_T sum = VEC4_T(0);
145+
VEC4_T sum_sq = VEC4_T(0);
146146
int count = 0;
147147

148148
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
149149
// the reduction row
150150
scan_pos[reduce_dim] = tid.x;
151151
for (int i = tid.x * 4; i < reduce_len;
152152
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
153-
vec4 val = load_texel(tin, scan_pos);
153+
VEC4_T val = load_texel(tin, scan_pos);
154154
sum += val;
155155
sum_sq += val * val;
156156
count += 4;
@@ -159,7 +159,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
159159
// element of the texel needs to be processed individually such that the
160160
// padding elements are ignored
161161
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
162-
const vec4 val = load_texel(tin, scan_pos);
162+
const VEC4_T val = load_texel(tin, scan_pos);
163163
for (int i = 0; i < nspill; i++) {
164164
sum.x += val[i];
165165
sum_sq.x += val[i] * val[i];
@@ -198,7 +198,7 @@ void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
198198
}
199199

200200
scan_pos[reduce_dim] = tid.x;
201-
write_texel(tout, scan_pos, vec4(variance, 0, 0, 0));
201+
write_texel(tout, scan_pos, VEC4_T(variance, 0, 0, 0));
202202
}
203203
}
204204

backends/vulkan/runtime/graph/ops/glsl/var.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/var_texture3d.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
var:
7+
var_texture3d:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
STORAGE: texture3d
@@ -13,4 +13,4 @@ var:
1313
- VALUE: half
1414
- VALUE: float
1515
shader_variants:
16-
- NAME: var
16+
- NAME: var_texture3d

backends/vulkan/runtime/graph/ops/impl/Var.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ void add_var_buffer_node(
3939
int32_t reduce_dim = normalize(dim, ndim);
4040
reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim);
4141

42-
std::string kernel_name = "var_buffer";
42+
std::string kernel_name = "var";
4343
kernel_name.reserve(kShaderNameReserve);
44+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
4445
add_dtype_suffix(kernel_name, graph.dtype_of(out));
4546

4647
const uint32_t nworkers_per_group = 4;
@@ -56,7 +57,8 @@ void add_var_buffer_node(
5657
std::vector<PushConstantDataInfo> push_constants;
5758
int32_t unbiased_int = static_cast<int32_t>(unbiased);
5859
push_constants.emplace_back(
59-
PushConstantDataInfo(&unbiased_int, sizeof(unbiased_int)));
60+
PushConstantDataInfo(
61+
&unbiased_int, sizeof(unbiased_int)));
6062

6163
graph.execute_nodes().emplace_back(new DispatchNode(
6264
graph,
@@ -103,6 +105,7 @@ void add_var_texture_node(
103105

104106
std::string kernel_name = "var";
105107
kernel_name.reserve(kShaderNameReserve);
108+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
106109
add_dtype_suffix(kernel_name, graph.dtype_of(out));
107110

108111
// This should match the value of MAX_NTHREADS in the softmax shader.
@@ -131,7 +134,8 @@ void add_var_texture_node(
131134
std::vector<PushConstantDataInfo> push_constants;
132135
int32_t unbiased_int = static_cast<int32_t>(unbiased);
133136
push_constants.emplace_back(
134-
PushConstantDataInfo(&unbiased_int, sizeof(unbiased_int)));
137+
PushConstantDataInfo(
138+
&unbiased_int, sizeof(unbiased_int)));
135139

136140
graph.execute_nodes().emplace_back(new DispatchNode(
137141
graph,

0 commit comments

Comments
 (0)