Skip to content

Commit 8f6c16e

Browse files
Kush Rastogi authored and facebook-github-bot committed
Removing q_linear.h and adding tiled q_linear implementation (#5492)
Summary: Pull Request resolved: #5492 Removes q_linear.h and moves implementation directly to q_8w_linear.glsl Reviewed By: nathanaelsee, jorgep31415 Differential Revision: D61309097 fbshipit-source-id: a4776b489b393fc6bafe93e174278e88177d0307
1 parent 8d6d18a commit 8f6c16e

File tree

6 files changed

+394
-84
lines changed

6 files changed

+394
-84
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __contains__(self, op):
8484
exir_ops.edge.aten.addmm.default,
8585
exir_ops.edge.aten.linear.default,
8686
exir_ops.edge.et_vk.linear_weight_int4.default,
87+
exir_ops.edge.aten._weight_int8pack_mm.default,
8788
# Reduction
8889
exir_ops.edge.aten._log_softmax.default,
8990
exir_ops.edge.aten._softmax.default,

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,38 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4444
// This header file must be defined after the layout descriptors have been
4545
// declared because the functions in the header assume some variables have been
4646
// declared as layout descriptors.
47-
#include "q_linear.h"
4847

4948
#ifdef USING_BUFFER
5049

50+
#ifndef FLOAT_T
51+
#define FLOAT_T float
52+
#endif
53+
54+
FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
55+
const FLOAT_T scale = t_scales[out_idx.x];
56+
57+
FLOAT_T outval = FLOAT_T(0.0);
58+
59+
// Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
60+
int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
61+
// Initial qmat2 tensor idx will be (0, out_idx.x, 0, 0); note that the qmat2
62+
// tensor is transposed
63+
int qmat2_offset = out_idx.x * qmat2_strides.y;
64+
65+
// TODO(ssjia): optimize memory access pattern by traversing K in inner loop
66+
for (int i = 0; i < K; i++) {
67+
const FLOAT_T mat1_val = t_mat1[mat1_offset];
68+
const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
69+
70+
outval += mat1_val * mat2_val;
71+
72+
mat1_offset++;
73+
qmat2_offset++;
74+
}
75+
76+
return outval;
77+
}
78+
5179
void main() {
5280
const int out_bufi = int(gl_GlobalInvocationID.x);
5381
if (out_bufi >= out_numel) {
@@ -61,6 +89,36 @@ void main() {
6189

6290
#else // USING_TEXTURE
6391

92+
VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
93+
ivec3 mat1_pos = ivec3(0, out_pos.yz);
94+
ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0);
95+
96+
VEC4_T outtex = VEC4_T(0);
97+
98+
const ivec3 scales_pos = ivec3(out_pos.x, 0, 0);
99+
const VEC4_T scales = load_texel(t_scales, scales_pos);
100+
101+
for (int i = 0; i < K; i += 4) {
102+
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
103+
104+
const VEC4_T sums = VEC4_T(
105+
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
106+
dot(mat1_tex,
107+
load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y),
108+
dot(mat1_tex,
109+
load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z),
110+
dot(mat1_tex,
111+
load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w));
112+
113+
outtex += sums;
114+
115+
mat1_pos.x++;
116+
qmat2_pos.x++;
117+
}
118+
119+
return outtex;
120+
}
121+
64122
void main() {
65123
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
66124
if (any(greaterThanEqual(out_pos, out_limits))) {
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
#define FLOAT_T ${buffer_scalar_type(DTYPE)}
15+
16+
${define_active_storage_type(STORAGE)}
17+
18+
${define_required_extensions(DTYPE)}
19+
${define_required_extensions("int8")}
20+
21+
22+
$if BATCH_MODE:
23+
#define BATCH_MODE
24+
25+
#define TILE_ROWS ${TILE_ROWS}
26+
#define FOUR 4
27+
28+
// we avoid mat4 and vec4 usage here as they compile to much less efficient
29+
// SPIR-V
30+
struct FloatMatrix_2d {
31+
float data[TILE_ROWS][FOUR];
32+
};
33+
34+
struct FloatMatrix_3d {
35+
float data[TILE_ROWS][FOUR][FOUR];
36+
};
37+
38+
#ifdef BATCH_MODE
39+
#define FloatMatrix FloatMatrix_3d
40+
#else
41+
#define FloatMatrix FloatMatrix_2d
42+
#endif
43+
44+
#include "indexing_utils.h"
45+
46+
layout(std430) buffer;
47+
48+
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
49+
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
50+
${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
51+
${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
52+
53+
$if STORAGE == "buffer":
54+
${layout_declare_ubo(4, "ivec4", "out_sizes")}
55+
${layout_declare_ubo(5, "ivec4", "out_strides")}
56+
${layout_declare_ubo(6, "int", "out_numel")}
57+
${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
58+
${layout_declare_ubo(8, "ivec4", "mat1_strides")}
59+
${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
60+
${layout_declare_ubo(10, "ivec4", "scales_strides")}
61+
$else:
62+
${layout_declare_ubo(4, "ivec3", "out_limits")}
63+
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
64+
65+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
66+
67+
// This header file must be defined after the layout descriptors have been
68+
// declared because the functions in the header assume some variables have been
69+
// declared as layout descriptors.
70+
71+
#ifdef USING_BUFFER
72+
73+
#ifndef FLOAT_T
74+
#define FLOAT_T float
75+
#endif
76+
77+
FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
78+
const FLOAT_T scale = t_scales[out_idx.x];
79+
80+
FLOAT_T outval = FLOAT_T(0.0);
81+
82+
// Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
83+
int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
84+
* Initial qmat2 tensor idx will be (0, out_idx.x, 0, 0); note that the qmat2
85+
// tensor is transposed
86+
int qmat2_offset = out_idx.x * qmat2_strides.y;
87+
88+
// TODO(ssjia): optimize memory access pattern by traversing K in inner loop
89+
for (int i = 0; i < K; i++) {
90+
const FLOAT_T mat1_val = t_mat1[mat1_offset];
91+
const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
92+
93+
outval += mat1_val * mat2_val;
94+
95+
mat1_offset++;
96+
qmat2_offset++;
97+
}
98+
99+
return outval;
100+
}
101+
102+
void main() {
103+
const int out_bufi = int(gl_GlobalInvocationID.x);
104+
if (out_bufi >= out_numel) {
105+
return;
106+
}
107+
108+
const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
109+
110+
t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
111+
}
112+
113+
#else // USING_TEXTURE
114+
FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) {
115+
FloatMatrix results;
116+
for (int i = 0; i < TILE_ROWS; i++) {
117+
for (int j = 0; j < FOUR; j++) {
118+
#ifdef BATCH_MODE
119+
for (int k = 0; k < FOUR; k++) {
120+
results.data[i][j][k] = 0.0f;
121+
}
122+
#else
123+
results.data[i][j] = 0.0f;
124+
#endif // BATCH_MODE
125+
}
126+
}
127+
128+
VEC4_T im_mat1_partial_load[TILE_ROWS];
129+
VEC4_T im_mat2_partial_load[FOUR];
130+
131+
#ifdef BATCH_MODE
132+
for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
133+
if (out_idx_tl.z + batch_idx >= out_limits.z) {
134+
break;
135+
}
136+
#endif
137+
for (int k = 0; k < mat1_sizes.x; k++) {
138+
for (int r = 0; r < TILE_ROWS; r++) {
139+
ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0);
140+
#ifdef BATCH_MODE
141+
mat1_pos[2] = out_idx_tl.z + batch_idx;
142+
#endif
143+
144+
im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0);
145+
}
146+
147+
for (int r = 0; r < FOUR; ++r) {
148+
ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0);
149+
150+
im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0);
151+
}
152+
153+
vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0);
154+
155+
// perform partial dot products and add partial result to results
156+
for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
157+
for (int out_col = 0; out_col < FOUR; out_col++) {
158+
#ifdef BATCH_MODE
159+
results.data[out_row][out_col][batch_idx] +=
160+
#else
161+
results.data[out_row][out_col] +=
162+
#endif
163+
dot(im_mat1_partial_load[out_row],
164+
im_mat2_partial_load[out_col] * scales[out_col]);
165+
}
166+
}
167+
}
168+
#ifdef BATCH_MODE
169+
}
170+
#endif
171+
return results;
172+
}
173+
174+
void main() {
175+
const ivec3 out_idx = ivec3(gl_GlobalInvocationID);
176+
if (any(greaterThanEqual(out_idx, out_limits))) {
177+
return;
178+
}
179+
180+
FloatMatrix results = q_8w_linear_optimized(out_idx);
181+
182+
ivec3 out_pos = ivec3(
183+
out_idx.x,
184+
out_idx.y * TILE_ROWS,
185+
#ifdef BATCH_MODE
186+
out_idx.z * 4
187+
#else
188+
out_idx.z
189+
#endif
190+
);
191+
192+
for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) {
193+
out_pos.x = out_idx.x;
194+
$if BATCH_MODE:
195+
for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) {
196+
write_texel(t_out, out_pos, VEC4_T(
197+
results.data[idx_c][idx_r][0],
198+
results.data[idx_c][idx_r][1],
199+
results.data[idx_c][idx_r][2],
200+
results.data[idx_c][idx_r][3]));
201+
}
202+
$else:
203+
write_texel(t_out, out_pos, VEC4_T(
204+
results.data[idx_c][0],
205+
results.data[idx_c][1],
206+
results.data[idx_c][2],
207+
results.data[idx_c][3]));
208+
}
209+
}
210+
211+
#endif
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
q_8w_linear_optimized:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
MAT1_PACKING: W_packed
12+
MAT2_PACKING: W_packed
13+
BATCH_MODE: false
14+
TILE_ROWS: 4
15+
generate_variant_forall:
16+
TILE_ROWS:
17+
- VALUE: 4
18+
SUFFIX: tile_row_4
19+
- VALUE: 2
20+
SUFFIX: tile_row_2
21+
DTYPE:
22+
- VALUE: float
23+
- VALUE: half
24+
STORAGE:
25+
- VALUE: texture3d
26+
- VALUE: buffer
27+
shader_variants:
28+
- NAME: q_8w_linear_optimized_W_packed_W_packed
29+
- NAME: q_8w_linear_optimized_W_packed_H_packed
30+
MAT2_PACKING: H_packed
31+
- NAME: batch_q_8w_linear_optimized_W_packed_W_packed
32+
BATCH_MODE: true
33+
- NAME: batch_q_8w_linear_optimized_W_packed_H_packed
34+
MAT2_PACKING: H_packed
35+
BATCH_MODE: true

0 commit comments

Comments (0)