Skip to content

Commit 2eae7a9

Browse files
SS-JIA authored and facebook-github-bot committed
Move QMat2 to buffer storage and scales_and_zeros to Channels Packed (#5515)
Summary: Pull Request resolved: #5515 Storing QMat2 in a texture gives rise to two main problems: - Indexing is a mess and additional computation is required to take into account the fact that we are reading ivec4's and only using half of the values - There is no texel fetching in int8. The texel is read in int32 and needs to be casted Keeping QMat2 in a buffer performs better because, although reading from buffers is slower, removing the extra computation compensates for this. {F1863459327} This diff also moves the scales_and_zeros tensor to Channels Packed in texture implementations because it just makes more sense, I had done some terrible indexing shenanigans before. ghstack-source-id: 244258611 exported-using-ghexport Reviewed By: yipjustin Differential Revision: D62504978 fbshipit-source-id: df2fdf87f75140be0a316576c8ffad67feefd6d7
1 parent 8be3ce5 commit 2eae7a9

File tree

3 files changed

+42
-41
lines changed

3 files changed

+42
-41
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 18 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,14 @@ layout(std430) buffer;
2626

2727
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
2828
${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
29-
${layout_declare_tensor(2, "r", "t_mat2", "int8", STORAGE)}
29+
${layout_declare_tensor(2, "r", "t_mat2", "int8", "buffer")}
3030
${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}
3131

3232
$if STORAGE == "texture3d":
3333
${layout_declare_ubo(4, "ivec4", "out_sizes")}
3434
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
35-
${layout_declare_ubo(6, "ivec4", "scales_strides")}
35+
${layout_declare_ubo(6, "ivec4", "mat2_strides")}
36+
${layout_declare_ubo(7, "ivec4", "scales_strides")}
3637
$else:
3738
${layout_declare_ubo(4, "ivec4", "out_sizes")}
3839
${layout_declare_ubo(5, "ivec4", "out_strides")}
@@ -64,9 +65,9 @@ void main() {
6465

6566
float rc = 0.0;
6667
int k = 0;
68+
const uint k_block = (K + group_size - 1) / group_size;
6769

6870
#ifdef USING_BUFFER
69-
const uint k_block = (K + group_size - 1) / group_size;
7071
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
7172
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
7273
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
@@ -101,42 +102,30 @@ void main() {
101102
t_out[out_bufi] = FLOAT_T(rc);
102103

103104
#else // Using texture
104-
const uint texel_group_size = group_size / FOUR;
105-
const uint k_block = (K + texel_group_size - 1) / texel_group_size;
106105
ivec3 mat1_pos = ivec3(0, m, out_pos.z);
107-
ivec3 mat2_pos = ivec3(0, n, out_pos.z);
108-
ivec3 scale_pos = ivec3(0, n, 0);
109-
ivec3 zero_pos = ivec3(0, n, 1);
106+
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
107+
ivec3 scale_zero_pos = ivec3(0, n, 0);
108+
uint K_texel = K / FOUR;
110109

111110
for (int kb = 0; kb < k_block; kb++) {
112-
const int texel_kb = kb / FOUR;
113-
const int kb_offset = kb % FOUR;
114-
115-
scale_pos.x = texel_kb;
116-
const VEC4_T scale_texel = load_texel(t_scales_and_zeros, scale_pos);
117-
const float scale = float(scale_texel[kb_offset]);
111+
scale_zero_pos.x = kb;
112+
const vec4 scale_zero = load_texel(t_scales_and_zeros, scale_zero_pos);
113+
const float scale = scale_zero.x;
114+
const float zero = scale_zero.y - scale * 8.0;
118115

119-
zero_pos.x = texel_kb;
120-
const VEC4_T zero_texel = load_texel(t_scales_and_zeros, zero_pos);
121-
const float zero = float(zero_texel[kb_offset]) - scale * 8.0;
122-
123-
for(uint idx = 0; idx < texel_group_size && k < K; idx++, k++) {
116+
for(uint idx = 0; idx < group_size && k < K_texel; idx += FOUR, k++) {
124117
mat1_pos.x = k;
125118
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
126119

127-
mat2_pos.x = k / 2;
128-
const i8vec4 mat2_tex = i8vec4(load_texel(t_mat2, mat2_pos));
120+
mat2_pos.x = k * 2; // k * FOUR / 2
121+
const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides);
129122

130-
// Every two texels of mat1 correspond to one texel of mat2
131-
// Even mat1 indeces correspond to first half of mat2 texel and
132-
// odd indeces correspond to second half
133-
const int mat2_offset = (k & 1) == 0 ? 0 : 2;
134-
for (int texel_idx = 0; texel_idx < FOUR; texel_idx++){
123+
for (int texel_pos = 0; texel_pos < FOUR; texel_pos++) {
135124
// Bitwise op treats sign bit from int8 as a value bit instead,
136125
// since there is no uint8_t datatype
137-
uint mat2_val = (mat2_tex[mat2_offset + texel_idx / 2] & 0xFF);
138-
mat2_val = (texel_idx & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
139-
rc += mat1_tex[texel_idx] * (scale * float(mat2_val) + zero);
126+
uint mat2_val = (t_mat2[mat2_id + texel_pos / 2] & 0xFF);
127+
mat2_val = (texel_pos & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
128+
rc += mat1_tex[texel_pos] * (scale * float(mat2_val) + zero);
140129
}
141130
}
142131
}

backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,13 @@ void check_q_matmul_args(
3333
using namespace WHCN;
3434
VK_CHECK_COND(graph.packed_dim_of(mat1) == kWidthDim);
3535
VK_CHECK_COND(graph.packed_dim_of(mat2_data) == kWidthDim);
36-
VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim);
36+
// VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim);
37+
38+
if (graph.storage_type_of(scales_and_zeros) == utils::kBuffer) {
39+
VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kWidthDim);
40+
} else {
41+
VK_CHECK_COND(graph.packed_dim_of(scales_and_zeros) == kChannelsDim);
42+
}
3743

3844
if (graph.storage_type_of(out) == utils::kBuffer) {
3945
VK_CHECK_COND(graph.packed_dim_of(out) == kWidthDim);
@@ -106,13 +112,8 @@ void add_q_matmul_node(
106112
const ValueRef out) {
107113
auto storage_type = graph.storage_type_of(out);
108114

109-
ValueRef mat2;
110-
111-
if (storage_type == utils::kBuffer) {
112-
mat2 = prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
113-
} else {
114-
mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
115-
}
115+
ValueRef mat2 =
116+
prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
116117

117118
ValueRef scales_and_zeros =
118119
prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked);
@@ -135,6 +136,7 @@ void add_q_matmul_node(
135136
} else {
136137
ubos.append(graph.sizes_ubo(out));
137138
ubos.append(graph.sizes_ubo(mat1));
139+
ubos.append(graph.strides_ubo(mat2));
138140
ubos.append(graph.strides_ubo(scales_and_zeros));
139141
}
140142

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2932,16 +2932,26 @@ void test_int4pack_mm(
29322932
int4mm_pack_weights(mat2_size, B_quant_data.data());
29332933

29342934
IOValueRef B_int4 =
2935-
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, storage_type);
2935+
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer);
29362936
graph.copy_into_staging(
29372937
B_int4.staging, B_int4_data.data(), B_int4_data.size());
29382938

29392939
const int k_groups = K / group_size;
29402940

29412941
// Random scales and zeroes. Keep scales small to avoid overflow and zeroes in
29422942
// int4 range
2943-
IOValueRef scales_and_zeros =
2944-
graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, storage_type);
2943+
IOValueRef scales_and_zeros;
2944+
2945+
if (storage_type == utils::kBuffer) {
2946+
scales_and_zeros.value = graph.add_tensor(
2947+
{2, N, k_groups}, vkapi::kFloat, storage_type, utils::kWidthPacked);
2948+
} else {
2949+
scales_and_zeros.value = graph.add_tensor(
2950+
{2, N, k_groups}, vkapi::kFloat, storage_type, utils::kChannelsPacked);
2951+
}
2952+
2953+
scales_and_zeros.staging = graph.set_input_tensor(scales_and_zeros.value);
2954+
29452955
std::vector<float> s_data(graph.numel_of(scales_and_zeros.value));
29462956
const int zeros_stride = s_data.size() / 2;
29472957
for (size_t i = 0; i < zeros_stride; i++) {
@@ -3003,7 +3013,7 @@ void test_int4pack_mm(
30033013
out_deq.staging, out_deq_data.data(), out_deq_data.size());
30043014

30053015
for (int i = 0; i < out_int4_data.size(); i++) {
3006-
CHECK_VALUE(out_int4_data, i, out_deq_data[i]);
3016+
EXPECT_TRUE(check_close(out_int4_data[i], out_deq_data[i]));
30073017
}
30083018
}
30093019

0 commit comments

Comments
 (0)