Skip to content

Commit 140ea1b

Browse files
author
morelos
committed
[ET-VK][Ops] dequantization op shaders and impl
Creating the dequantize_per_tensor and dequantize_per_token logic shaders and impl which are linked with the testing framework. Differential Revision: [D76267107](https://our.internmc.facebook.com/intern/diff/D76267107/) ghstack-source-id: 289116371 Pull Request resolved: #11483
1 parent ec886b8 commit 140ea1b

File tree

4 files changed

+644
-41
lines changed

4 files changed

+644
-41
lines changed
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define IVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)}
15+
16+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
17+
#define FVEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)}
18+
19+
${define_active_storage_type(STORAGE)}
20+
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(OUT_DTYPE)}
22+
23+
// Need this in order to properly handle overflow for dequantize_val
24+
// since there is an inconsistency between the cpu logic
25+
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
26+
27+
#extension GL_EXT_control_flow_attributes : require
28+
29+
layout(std430) buffer;
30+
31+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)}
32+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, STORAGE)}
33+
34+
$if MODE == "per_tensor":
35+
layout(push_constant) uniform restrict Block {
36+
float scale;
37+
int zero_point;
38+
int quant_min;
39+
int quant_max;
40+
};
41+
$else:
42+
${layout_declare_tensor(B, "r", "t_scale", "float", STORAGE)}
43+
${layout_declare_tensor(B, "r", "t_zero_point", "int", STORAGE)}
44+
45+
layout(push_constant) uniform restrict Block {
46+
int num_tokens;
47+
int quant_min;
48+
int quant_max;
49+
};
50+
51+
$if STORAGE == "buffer":
52+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
53+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
54+
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
55+
${layout_declare_ubo(B, "ivec4", "t_out_strides")}
56+
$else:
57+
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
58+
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
59+
60+
#include "indexing_utils.h"
61+
62+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
63+
64+
OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) {
65+
$if MODE == "per_tensor":
66+
// out_data_ptr[i] = static_cast<OUT_CTYPE>((input_data_ptr[i] - static_cast<int32_t>(zero_point)) * scale);
67+
// -2147483648 - 100 = 2147483548 * 0.0001 > cast to float32
68+
OUT_T value = OUT_T(float(int(qvalue) - zero_point_val) * scale_val);
69+
70+
$if MODE == "per_token":
71+
// out_data_ptr[i] = static_cast<OUT_CTYPE>((input_data_ptr[i] - zero_point) * scale);
72+
// -2147483648 - 100 = -2147483748 * 0.0001 > cast to float32
73+
OUT_T value = OUT_T(float(int(qvalue) - int64_t(zero_point_val)) * scale_val);
74+
75+
return value;
76+
}
77+
78+
#ifdef USING_BUFFER
79+
80+
void main() {
81+
$if MODE == "per_tensor":
82+
const ivec4 pos = ivec4(
83+
gl_GlobalInvocationID.x,
84+
gl_GlobalInvocationID.y,
85+
gl_GlobalInvocationID.z,
86+
0);
87+
88+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
89+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
90+
91+
IN_T qvalue = t_in[t_in_idx];
92+
OUT_T value;
93+
94+
value = dequantize_val(qvalue, scale, zero_point);
95+
96+
t_out[t_out_idx] = value;
97+
98+
$if MODE == "per_token":
99+
const ivec4 pos = ivec4(
100+
gl_GlobalInvocationID.x,
101+
gl_GlobalInvocationID.y,
102+
gl_GlobalInvocationID.z,
103+
0);
104+
105+
const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
106+
const int t_out_idx = tidx_to_bufi(pos, t_out_strides);
107+
108+
// Skip if out of bounds
109+
if (t_in_idx >= t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w) {
110+
return;
111+
}
112+
113+
IN_T qvalue = t_in[t_in_idx];
114+
OUT_T value;
115+
116+
// Calculate logical position from linear index and strides
117+
ivec4 logical_pos;
118+
int remaining = t_in_idx;
119+
120+
logical_pos.x = remaining % t_in_sizes.x;
121+
remaining /= t_in_sizes.x;
122+
123+
logical_pos.y = remaining % t_in_sizes.y;
124+
remaining /= t_in_sizes.y;
125+
126+
logical_pos.z = remaining % t_in_sizes.z;
127+
remaining /= t_in_sizes.z;
128+
129+
logical_pos.w = remaining;
130+
131+
// Calculate token index based on logical position
132+
int token_idx = 0;
133+
134+
// Check dimensions to determine how to calculate token_idx
135+
if (t_in_sizes.w > 1) {
136+
// 4D tensor
137+
token_idx = logical_pos.w * (t_in_sizes.z * t_in_sizes.y) + logical_pos.z * t_in_sizes.y + logical_pos.y;
138+
} else if (t_in_sizes.z > 1) {
139+
// 3D tensor
140+
token_idx = logical_pos.z * t_in_sizes.y + logical_pos.y;
141+
} else if (t_in_sizes.y > 1) {
142+
// 2D tensor
143+
token_idx = logical_pos.y;
144+
}
145+
// For 1D tensor, token_idx remains 0
146+
147+
// Make sure token_idx is within bounds
148+
token_idx = min(token_idx, num_tokens - 1);
149+
150+
value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]);
151+
152+
t_out[t_out_idx] = value;
153+
}
154+
155+
#else
156+
157+
void main() {
158+
$if MODE == "per_tensor":
159+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
160+
161+
// Skip if out of bounds
162+
if (any(greaterThanEqual(pos, t_in_limits))) {
163+
return;
164+
}
165+
166+
IVEC4_T intex = load_texel(t_in, pos);
167+
FVEC4_T outtex;
168+
169+
[[unroll]] for (int i = 0; i < 4; ++i) {
170+
IN_T qvalue = IN_T(intex[i]);
171+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
172+
outtex[i] = value;
173+
}
174+
write_texel(t_out, pos, outtex);
175+
176+
$if MODE == "per_token":
177+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
178+
179+
// Skip if out of bounds
180+
if (any(greaterThanEqual(pos, t_in_limits))) {
181+
return;
182+
}
183+
184+
IVEC4_T intex = load_texel(t_in, pos);
185+
186+
int token_idx = 0;
187+
ivec3 dims = t_in_limits;
188+
189+
if (dims.z > 1) {
190+
// 3D tensor
191+
token_idx = pos.z * dims.y + pos.y;
192+
} else if (dims.y > 1) {
193+
// 2D tensor
194+
token_idx = pos.y;
195+
}
196+
// For 1D tensor, token_idx remains 0
197+
198+
// Make sure token_idx is within bounds
199+
token_idx = min(token_idx, num_tokens - 1);
200+
201+
// For texture storage, we need to calculate the texel position and component index
202+
int texel_idx = token_idx / 4;
203+
int comp_idx = token_idx % 4;
204+
205+
vec4 scale_vals = load_texel(t_scale, ivec3(texel_idx, 0, 0));
206+
ivec4 zp_vals = load_texel(t_zero_point, ivec3(texel_idx, 0, 0));
207+
208+
float scale_val = scale_vals[comp_idx];
209+
int zero_point_val = zp_vals[comp_idx];
210+
211+
FVEC4_T outtex;
212+
[[unroll]] for (int i = 0; i < 4; ++i) {
213+
IN_T qvalue = IN_T(intex[i]);
214+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
215+
outtex[i] = value;
216+
}
217+
218+
write_texel(t_out, pos, outtex);
219+
220+
}
221+
222+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
dequantize:
2+
parameter_names_with_default_values:
3+
IN_DTYPE: int
4+
OUT_DTYPE: float
5+
STORAGE: buffer
6+
MODE: per_tensor
7+
generate_variant_forall:
8+
STORAGE:
9+
- VALUE: buffer
10+
- VALUE: texture3d
11+
IN_DTYPE:
12+
- VALUE: uint8
13+
- VALUE: int8
14+
- VALUE: int32
15+
OUT_DTYPE:
16+
- VALUE: half
17+
- VALUE: float
18+
shader_variants:
19+
- NAME: dequantize_per_tensor
20+
MODE: per_tensor
21+
- NAME: dequantize_per_token
22+
MODE: per_token

0 commit comments

Comments
 (0)