Skip to content

Commit 9b20d26

Browse files
author
morelos
committed
[ET-VK][Ops] dequantization op shaders and impl
Pull Request resolved: #11483 Creating the dequantize_per_tensor and dequantize_per_token logic shaders and impl which are linked with the testing framework. ghstack-source-id: 289798621 @exported-using-ghexport Differential Revision: [D76267107](https://our.internmc.facebook.com/intern/diff/D76267107/)
1 parent 4447754 commit 9b20d26

File tree

4 files changed

+671
-4
lines changed

4 files changed

+671
-4
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define IVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)}
15+
16+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
17+
#define FVEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)}
18+
19+
${define_active_storage_type(STORAGE)}
20+
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(OUT_DTYPE)}
22+
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
layout(std430) buffer;
26+
27+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)}
28+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, STORAGE)}
29+
30+
$if MODE == "per_tensor":
31+
layout(push_constant) uniform restrict Block {
32+
float scale;
33+
int zero_point;
34+
int quant_min;
35+
int quant_max;
36+
};
37+
$else:
38+
${layout_declare_tensor(B, "r", "t_scale", "float", STORAGE)}
39+
${layout_declare_tensor(B, "r", "t_zero_point", "int", STORAGE)}
40+
41+
layout(push_constant) uniform restrict Block {
42+
int num_tokens;
43+
int quant_min;
44+
int quant_max;
45+
};
46+
47+
$if STORAGE == "buffer":
48+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
49+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
50+
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
51+
${layout_declare_ubo(B, "ivec4", "t_out_strides")}
52+
$else:
53+
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
54+
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
55+
56+
#include "indexing_utils.h"
57+
58+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
59+
60+
OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) {
61+
return OUT_T(float(int(qvalue) - zero_point_val) * scale_val);
62+
}
63+
64+
#ifdef USING_BUFFER
65+
66+
void main() {
67+
$if MODE == "per_tensor":
68+
const ivec4 pos = ivec4(
69+
gl_GlobalInvocationID.x,
70+
gl_GlobalInvocationID.y,
71+
gl_GlobalInvocationID.z,
72+
0);
73+
74+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
75+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
76+
77+
IN_T qvalue = t_in[t_in_idx];
78+
OUT_T value;
79+
80+
value = dequantize_val(qvalue, scale, zero_point);
81+
82+
t_out[t_out_idx] = value;
83+
84+
$if MODE == "per_token":
85+
const ivec4 pos = ivec4(
86+
gl_GlobalInvocationID.x,
87+
gl_GlobalInvocationID.y,
88+
gl_GlobalInvocationID.z,
89+
0);
90+
91+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
92+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
93+
94+
// Skip if out of bounds
95+
if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
96+
return;
97+
}
98+
99+
IN_T qvalue = t_in[t_in_idx];
100+
OUT_T value;
101+
102+
// Calculate logical position from linear index and strides
103+
ivec4 logical_pos;
104+
int remaining = t_in_idx;
105+
106+
logical_pos.x = remaining % t_in_sizes.x;
107+
remaining /= t_in_sizes.x;
108+
109+
logical_pos.y = remaining % t_in_sizes.y;
110+
remaining /= t_in_sizes.y;
111+
112+
logical_pos.z = remaining % t_in_sizes.z;
113+
remaining /= t_in_sizes.z;
114+
115+
logical_pos.w = remaining;
116+
117+
// Calculate token index based on logical position
118+
int token_idx = 0;
119+
120+
// Check dimensions to determine how to calculate token_idx
121+
if (t_in_sizes.w > 1) {
122+
// 4D tensor
123+
token_idx = logical_pos.w * (t_in_sizes.z * t_in_sizes.y) + logical_pos.z * t_in_sizes.y + logical_pos.y;
124+
} else if (t_in_sizes.z > 1) {
125+
// 3D tensor
126+
token_idx = logical_pos.z * t_in_sizes.y + logical_pos.y;
127+
} else if (t_in_sizes.y > 1) {
128+
// 2D tensor
129+
token_idx = logical_pos.y;
130+
}
131+
// For 1D tensor, token_idx remains 0
132+
133+
// Make sure token_idx is within bounds
134+
token_idx = min(token_idx, num_tokens - 1);
135+
136+
value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
137+
138+
t_out[t_out_idx] = value;
139+
}
140+
141+
#else
142+
143+
void main() {
144+
$if MODE == "per_tensor":
145+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
146+
147+
// Skip if out of bounds
148+
if (any(greaterThanEqual(pos, t_in_limits))) {
149+
return;
150+
}
151+
152+
IVEC4_T intex = load_texel(t_in, pos);
153+
FVEC4_T outtex;
154+
155+
[[unroll]] for (int i = 0; i < 4; ++i) {
156+
IN_T qvalue = IN_T(intex[i]);
157+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
158+
outtex[i] = value;
159+
}
160+
write_texel(t_out, pos, outtex);
161+
162+
$if MODE == "per_token":
163+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
164+
165+
// Skip if out of bounds
166+
if (any(greaterThanEqual(pos, t_in_limits))) {
167+
return;
168+
}
169+
170+
IVEC4_T intex = load_texel(t_in, pos);
171+
172+
int token_idx = 0;
173+
ivec3 dims = t_in_limits;
174+
175+
if (dims.z > 1) {
176+
// 3D tensor
177+
token_idx = pos.z * dims.y + pos.y;
178+
} else if (dims.y > 1) {
179+
// 2D tensor
180+
token_idx = pos.y;
181+
}
182+
// For 1D tensor, token_idx remains 0
183+
184+
// Make sure token_idx is within bounds
185+
token_idx = min(token_idx, num_tokens - 1);
186+
187+
// For texture storage, we need to calculate the texel position and component index
188+
int texel_idx = token_idx / 4;
189+
int comp_idx = token_idx % 4;
190+
191+
vec4 scale_vals = load_texel(t_scale, ivec3(texel_idx, 0, 0));
192+
ivec4 zp_vals = load_texel(t_zero_point, ivec3(texel_idx, 0, 0));
193+
194+
float scale_val = scale_vals[comp_idx];
195+
int zero_point_val = zp_vals[comp_idx];
196+
197+
FVEC4_T outtex;
198+
[[unroll]] for (int i = 0; i < 4; ++i) {
199+
IN_T qvalue = IN_T(intex[i]);
200+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
201+
outtex[i] = value;
202+
}
203+
204+
write_texel(t_out, pos, outtex);
205+
206+
}
207+
208+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
dequantize:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: int
4+
OUT_DTYPE: float
5+
STORAGE: buffer
6+
MODE: per_tensor
7+
generate_variant_forall:
8+
STORAGE:
9+
- VALUE: buffer
10+
- VALUE: texture3d
11+
IN_DTYPE:
12+
- VALUE: uint8
13+
- VALUE: int8
14+
- VALUE: int32
15+
OUT_DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
shader_variants:
19+
- NAME: dequantize_per_tensor
20+
MODE: per_tensor
21+
- NAME: dequantize_per_token
22+
MODE: per_token

0 commit comments

Comments
 (0)