Skip to content

Commit f8c5a60

Browse files
Kush Rastogi authored and facebook-github-bot committed
Adding Tiled 2D and 3D Quantizer Linear Base Implementation (#5492)
Summary: Pull Request resolved: #5492 Adding Tiled Implementation of Weight-Only Quantized Linear operator This diff adds Texture Implementation, will add Buffer impl next. # Diff Stack 1. Add Tiled Implementation of Weight-Only Quantized Linear 2. Add Optimized Quantized Linear Shader and code to invoke shader from Quantized Linear CPP operator 3. [Will Not Land] Use Optimized Quantized Linear implementation Differential Revision: D61309097
1 parent 47f4f07 commit f8c5a60

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#version 450 core
1010

11+
#extension GL_EXT_control_flow_attributes : require
12+
1113
#define PRECISION ${PRECISION}
1214

1315
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

backends/vulkan/runtime/graph/ops/glsl/q_linear.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,24 @@
1111

1212
#include "indexing_utils.h"
1313

14+
// To convince the SPIR-V compiler to unroll the loops optimally, need this
15+
// macro
16+
#define FOUR 4
17+
18+
#ifdef TILE_ROW_2
19+
#define TILE_ROWS 2
20+
#else
21+
#define TILE_ROWS 4
22+
#endif
23+
24+
struct FloatMatrix_2d {
25+
float data[TILE_ROWS][FOUR];
26+
};
27+
28+
struct FloatMatrix_3d {
29+
float data[TILE_ROWS][FOUR][FOUR];
30+
};
31+
1432
// The functions in this file assume that some variables have been defined as
1533
// descriptors, such as t_mat1, t_qmat2, t_scales, etc.
1634

@@ -77,6 +95,101 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
7795
return outtex;
7896
}
7997

98+
/*
 * Computes a TILE_ROWS x 4 tile of the weight-only quantized (8-bit weight)
 * linear output for a 2D input. out_pos selects the tile; K is the reduction
 * length in texels. Reads t_mat1 (activations), t_qmat2 (quantized weights)
 * and t_scales (per-output-channel dequant scales), which must be declared
 * as descriptors by the including shader.
 */
FloatMatrix_2d q_8w_linear_optimized_2d(const ivec3 out_pos, const int K) {
  FloatMatrix_2d acc;

  // Zero the accumulator tile.
  [[unroll]] for (int r = 0; r < TILE_ROWS; r++) {
    [[unroll]] for (int c = 0; c < FOUR; c++) {
      acc.data[r][c] = 0.0f;
    }
  }

  VEC4_T mat1_tile[TILE_ROWS];
  ivec4 qmat2_tile[FOUR];

  for (int k = 0; k < K; k++) {
    // Cache a TILE_ROWS x 4 strip of the activations at reduction step k.
    [[unroll]] for (int r = 0; r < TILE_ROWS; r++) {
      const ivec3 mat1_pos = ivec3(k, out_pos.y * TILE_ROWS + r, 0);
      mat1_tile[r] = load_texel(t_mat1, mat1_pos);
    }
    // Cache the matching 4x4 tile of the quantized weights.
    [[unroll]] for (int c = 0; c < FOUR; c++) {
      const ivec3 qmat2_pos = ivec3(k, (FOUR * out_pos.x) + c, 0);
      qmat2_tile[c] = load_texel(t_qmat2, qmat2_pos);
    }

    // Accumulate the outer product of the two cached strips.
    [[unroll]] for (int r = 0; r < TILE_ROWS; r++) {
      [[unroll]] for (int c = 0; c < FOUR; c++) {
        acc.data[r][c] += dot(mat1_tile[r], qmat2_tile[c]);
      }
    }
  }

  // Apply the per-output-channel dequantization scales.
  const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0));
  [[unroll]] for (int r = 0; r < TILE_ROWS; r++) {
    [[unroll]] for (int c = 0; c < FOUR; c++) {
      acc.data[r][c] *= scales[c];
    }
  }
  return acc;
}
136+
137+
/*
 * Computes a TILE_ROWS x 4 output tile for up to FOUR batch entries of a
 * batched (3D) weight-only quantized linear. out_pos.z tiles the batch
 * dimension, so this invocation covers batches
 * [FOUR * out_pos.z, FOUR * out_pos.z + 3]; K is the reduction length in
 * texels and batch_size the total number of batches. Reads t_mat1, t_qmat2
 * and t_scales, which must be declared as descriptors by the including
 * shader.
 */
FloatMatrix_3d q_8w_linear_optimized_3d(
    const ivec3 out_pos,
    const int K,
    const int batch_size) {
  FloatMatrix_3d results;

  [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
    [[unroll]] for (int j = 0; j < FOUR; j++) {
      [[unroll]] for (int k = 0; k < FOUR; k++) {
        results.data[i][j][k] = 0.0f;
      }
    }
  }

  VEC4_T im_mat1_partial_load[TILE_ROWS];
  ivec4 im_mat2_partial_load[FOUR];

  const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0));

  // The accumulator holds only FOUR batch slots, so iterate at most FOUR
  // times. Bounding by batch_size instead would write results.data[..][..]
  // out of bounds whenever batch_size > 4 (e.g. out_pos.z == 0,
  // batch_size == 8 reaches batch_idx == 7 against a 4-slot array).
  for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
    // Tail guard: stop early when this tile extends past the last batch.
    if (FOUR * out_pos.z + batch_idx >= batch_size) {
      break;
    }
    int mat_z = FOUR * out_pos.z + batch_idx;
    for (int mat1_x = 0; mat1_x < K; mat1_x++) {
      [[unroll]] for (int offset = 0; offset < TILE_ROWS; offset++) {
        // read and cache 2x4 (or 4x4) tile of im_mat1
        const int mat1_y = (TILE_ROWS * out_pos.y) + offset;
        const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
        im_mat1_partial_load[offset] = load_texel(t_mat1, mat1_pos);
      }

      [[unroll]] for (int offset = 0; offset < FOUR; offset++) {
        // read and cache 4x4 tile of im_mat2 (weights are shared across
        // batches, hence z == 0)
        const int mat2_y = (FOUR * out_pos.x) + offset;
        const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
        im_mat2_partial_load[offset] = load_texel(t_qmat2, mat2_pos);
      }

      [[unroll]] for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
        [[unroll]] for (int out_col = 0; out_col < FOUR; out_col++) {
          results.data[out_row][out_col][batch_idx] +=
              dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
        }
      }
    }

    // Apply per-output-channel dequantization scales for this batch slot.
    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
      [[unroll]] for (int j = 0; j < FOUR; j++) {
        results.data[i][j][batch_idx] *= scales[j];
      }
    }
  }
  return results;
}
192+
80193
#endif // USING_BUFFER
81194

82195
#endif // Q_LINEAR_H

0 commit comments

Comments
 (0)