Skip to content

Commit 36f2cb5

Browse files
author
morelos
committed
[ET-VK][Ops] quantization op shaders and impl
Adds the quantize_per_tensor and quantize_per_token shaders and their implementation, and links them with the testing framework. NOTE: Currently the only supported input types are **half** (fp16) and **float** (fp32). The only supported output types are **byte** (uint8), **char** (int8), **short** (int16), and **int** (int32). Differential Revision: [D75959064](https://our.internmc.facebook.com/intern/diff/D75959064/) ghstack-source-id: 288187842 Pull Request resolved: #11369
1 parent 769b753 commit 36f2cb5

File tree

4 files changed

+607
-0
lines changed

4 files changed

+607
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
// Scalar and texel types of the input and output tensors. They are filled in
// by the shader template from IN_DTYPE / OUT_DTYPE (the texel types also
// depend on STORAGE).
13+
#define IN_T ${buffer_scalar_type(IN_DTYPE)}
14+
#define FVEC4_T ${texel_load_type(IN_DTYPE, STORAGE)}
15+
16+
#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
17+
#define IVEC4_T ${texel_load_type(OUT_DTYPE, STORAGE)}
18+
// NOTE(review): presumably the storage macro below is what defines
// USING_BUFFER, which selects between the two main() variants further down;
// confirm against the template helpers.
19+
${define_active_storage_type(STORAGE)}
20+
${define_required_extensions(IN_DTYPE)}
21+
${define_required_extensions(OUT_DTYPE)}
22+
// Provides the [[unroll]] loop attribute used by the texture code path.
23+
#extension GL_EXT_control_flow_attributes : require
24+
25+
layout(std430) buffer;
26+
// Input tensor is bound read-only ("r"), output tensor write-only ("w").
27+
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, STORAGE)}
28+
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, STORAGE)}
29+
// Quantization parameters.
// - per_tensor mode: a single scale / zero_point pair plus the clamp range
//   [quant_min, quant_max] are passed as push constants.
// - any other mode: scale and zero_point are read from the t_scale /
//   t_zero_point tensors, and the push constants carry num_tokens plus the
//   clamp range.
30+
$if MODE == "per_tensor":
31+
layout(push_constant) uniform restrict Block {
32+
float scale;
33+
int zero_point;
34+
int quant_min;
35+
int quant_max;
36+
};
37+
$else:
38+
${layout_declare_tensor(B, "r", "t_scale", "float", STORAGE)}
39+
${layout_declare_tensor(B, "r", "t_zero_point", "int", STORAGE)}
40+
41+
layout(push_constant) uniform restrict Block {
42+
int num_tokens;
43+
int quant_min;
44+
int quant_max;
45+
};
46+
// Tensor metadata: buffer storage gets the sizes and strides of both tensors,
// any other storage gets the ivec3 limits of both tensors.
47+
$if STORAGE == "buffer":
48+
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
49+
${layout_declare_ubo(B, "ivec4", "t_in_strides")}
50+
${layout_declare_ubo(B, "ivec4", "t_out_sizes")}
51+
${layout_declare_ubo(B, "ivec4", "t_out_strides")}
52+
$else:
53+
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
54+
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
55+
// NOTE(review): presumably the source of tidx_to_bufi, load_texel and
// write_texel used below; confirm against the header.
56+
#include "indexing_utils.h"
57+
// Work-group size is supplied through specialization constants 0, 1 and 2.
58+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
59+
60+
/*
 * Quantizes a single input element.
 *
 * Computes zero_point_val + round(value / scale_val), clamps the result to
 * the [quant_min, quant_max] push constants in signed 32-bit integer space,
 * and only then casts to OUT_T.
 *
 * value          - element to quantize
 * scale_val      - quantization scale; 0 is handled explicitly (see below)
 * zero_point_val - quantization zero point
 *
 * Returns the quantized value as OUT_T.
 */
OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) {
  // Bounds of the float range that converts to a 32-bit int without
  // overflow: -2^31 and 2^31 - 128 (the largest float below 2^31).
  const float int_conv_min = -2147483648.0;
  const float int_conv_max = 2147483520.0;

  const float fvalue = float(value);
  int qvalue;

  if (scale_val == 0.0) {
    // 1 / scale would be infinite: saturate according to the sign of the
    // input. An input of exactly 0 maps to the zero point.
    if (fvalue < 0.0) {
      qvalue = quant_min;
    } else if (fvalue > 0.0) {
      qvalue = quant_max;
    } else {
      qvalue = zero_point_val;
    }
  } else {
    const float inv_scale = 1.0 / scale_val;

    // round() leaves the direction of .5 ties up to the implementation;
    // roundEven() is deterministic across devices.
    // NOTE(review): assumes the CPU reference rounds ties to even
    // (nearbyint-style) - confirm.
    float rounded = roundEven(inv_scale * fvalue);

    // Converting a float outside the int range is undefined, so saturate in
    // float space first. Values already in range are left untouched.
    rounded = clamp(rounded, int_conv_min, int_conv_max);
    const int offset = int(rounded);

    // Integer addition wraps on overflow; detect the wrap and saturate so a
    // huge input cannot flip to the opposite end of the quantized range.
    qvalue = zero_point_val + offset;
    if (zero_point_val > 0 && offset > 0 && qvalue < 0) {
      qvalue = 2147483647;
    } else if (zero_point_val < 0 && offset < 0 && qvalue >= 0) {
      qvalue = -2147483647 - 1;
    }
  }

  // Clamp in int space before the final cast to the output type
  qvalue = max(qvalue, quant_min);
  qvalue = min(qvalue, quant_max);

  return OUT_T(qvalue);
}
91+
92+
#ifdef USING_BUFFER
93+
94+
/*
 * Buffer-storage entry point: each invocation quantizes one element of t_in
 * and stores the result in t_out.
 *
 * The invocation ID is used as a tensor index (fourth component fixed to 0)
 * and converted to buffer offsets with the input / output strides.
 */
void main() {
  const ivec4 pos = ivec4(
      gl_GlobalInvocationID.x,
      gl_GlobalInvocationID.y,
      gl_GlobalInvocationID.z,
      0);

  const int t_in_idx = tidx_to_bufi(pos, t_in_strides);
  const int t_out_idx = tidx_to_bufi(pos, t_out_strides);

  // Skip invocations that fall outside the input tensor. This guard applies
  // to every mode so that no invocation reads or writes past the buffers.
  const int in_numel =
      t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w;
  if (t_in_idx >= in_numel) {
    return;
  }

$if MODE == "per_tensor":
  // One scale / zero point (push constants) for the whole tensor
  t_out[t_out_idx] = quantize_val(t_in[t_in_idx], scale, zero_point);

$if MODE == "per_token":
  // A token is one row along the innermost (x) dimension. Splitting the
  // linear index into (x, y, z, w) using the sizes and recombining
  // everything except x yields exactly t_in_idx / t_in_sizes.x once the
  // bounds check above has passed, for 1D through 4D tensors.
  int token_idx = t_in_idx / t_in_sizes.x;

  // Make sure token_idx stays within the scale / zero point tensors
  token_idx = min(token_idx, num_tokens - 1);

  t_out[t_out_idx] = quantize_val(
      t_in[t_in_idx], t_scale[token_idx], t_zero_point[token_idx]);
}
168+
169+
#else
170+
171+
/*
 * Texture-storage entry point: each invocation quantizes the four components
 * of one input texel and writes the resulting texel to t_out at the same
 * position.
 */
void main() {
$if MODE == "per_tensor":
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // Nothing to do for invocations beyond the input limits
  if (any(greaterThanEqual(texel_pos, t_in_limits))) {
    return;
  }

  FVEC4_T in_texel = load_texel(t_in, texel_pos);
  IVEC4_T out_texel;

  // Same scale / zero point (push constants) for every component
  [[unroll]] for (int comp = 0; comp < 4; ++comp) {
    out_texel[comp] = quantize_val(IN_T(in_texel[comp]), scale, zero_point);
  }
  write_texel(t_out, texel_pos, out_texel);

$if MODE == "per_token":
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // Nothing to do for invocations beyond the input limits
  if (any(greaterThanEqual(texel_pos, t_in_limits))) {
    return;
  }

  FVEC4_T in_texel = load_texel(t_in, texel_pos);

  // Token index from the y / z position: z * height + y when the limits have
  // depth, y alone when they only have height, 0 otherwise.
  const ivec3 extents = t_in_limits;
  int token_idx = (extents.z > 1)
      ? texel_pos.z * extents.y + texel_pos.y
      : ((extents.y > 1) ? texel_pos.y : 0);

  // Keep the index inside the scale / zero point tensors
  token_idx = min(token_idx, num_tokens - 1);

  // Four per-token parameters share one texel: pick the texel, then the
  // component inside it.
  const ivec3 param_pos = ivec3(token_idx / 4, 0, 0);
  const int lane = token_idx % 4;

  vec4 scale_texel = load_texel(t_scale, param_pos);
  ivec4 zero_point_texel = load_texel(t_zero_point, param_pos);

  const float scale_val = scale_texel[lane];
  const int zero_point_val = zero_point_texel[lane];

  IVEC4_T out_texel;
  [[unroll]] for (int comp = 0; comp < 4; ++comp) {
    out_texel[comp] =
        quantize_val(IN_T(in_texel[comp]), scale_val, zero_point_val);
  }

  write_texel(t_out, texel_pos, out_texel);

}
235+
236+
#endif
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Variant configuration for the quantize shader template.
quantize:
2+
# Default values for the template parameters.
parameter_names_with_default_values:
3+
IN_DTYPE: float
4+
OUT_DTYPE: int
5+
STORAGE: buffer
6+
MODE: per_tensor
7+
# NOTE(review): presumably one shader is generated for every combination of
# the values listed below - confirm against the shader codegen.
generate_variant_forall:
8+
STORAGE:
9+
- VALUE: buffer
10+
- VALUE: texture3d
11+
# NOTE(review): the accompanying commit description names only half and float
# as supported input types - confirm that double is intended here.
IN_DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
- VALUE: double
15+
OUT_DTYPE:
16+
- VALUE: uint8
17+
- VALUE: int8
18+
- VALUE: short
19+
- VALUE: int
20+
# Named variants; each one fixes the MODE template parameter.
shader_variants:
21+
- NAME: quantize_per_tensor
22+
MODE: per_tensor
23+
- NAME: quantize_per_token
24+
MODE: per_token

0 commit comments

Comments
 (0)