Skip to content

Commit e73dce2

Browse files
authored
[INT4-MM| Add Texture3D storage type
Differential Revision: D62148863 Pull Request resolved: #5044
1 parent 79b97e4 commit e73dce2

File tree

4 files changed

+154
-64
lines changed

4 files changed

+154
-64
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 90 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
#define PRECISION ${PRECISION}
1414

15+
#define FOUR 4
16+
17+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
1518
#define FLOAT_T ${buffer_scalar_type(DTYPE)}
1619

1720
${define_active_storage_type(STORAGE)}
@@ -26,12 +29,17 @@ ${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
2629
${layout_declare_tensor(2, "r", "t_mat2", "int8", STORAGE)}
2730
${layout_declare_tensor(3, "r", "t_scales_and_zeros", DTYPE, STORAGE)}
2831

29-
${layout_declare_ubo(4, "ivec4", "out_sizes")}
30-
${layout_declare_ubo(5, "ivec4", "out_strides")}
31-
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
32-
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
33-
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
34-
${layout_declare_ubo(9, "ivec4", "scales_strides")}
32+
$if STORAGE == "texture3d":
33+
${layout_declare_ubo(4, "ivec4", "out_sizes")}
34+
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
35+
${layout_declare_ubo(6, "ivec4", "scales_strides")}
36+
$else:
37+
${layout_declare_ubo(4, "ivec4", "out_sizes")}
38+
${layout_declare_ubo(5, "ivec4", "out_strides")}
39+
${layout_declare_ubo(6, "ivec4", "mat1_sizes")}
40+
${layout_declare_ubo(7, "ivec4", "mat1_strides")}
41+
${layout_declare_ubo(8, "ivec4", "mat2_strides")}
42+
${layout_declare_ubo(9, "ivec4", "scales_strides")}
3543

3644
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3745

@@ -49,45 +57,90 @@ void main() {
4957
return;
5058
}
5159

52-
const uint K = mat2_sizes.x * 2;
53-
const uint N = mat2_sizes.y;
60+
const uint K = mat1_sizes.x;
5461
const uint n = out_pos.x;
5562
const uint m = out_pos.y;
56-
const uint k_block = (K + group_size - 1) / group_size;
5763
const uint mask = uint(0x0f);
58-
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
59-
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
60-
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
61-
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);
6264

6365
float rc = 0.0;
6466
int k = 0;
6567

66-
for (int kb = 0; kb < k_block; kb++) {
67-
scale_pos.x = kb;
68-
const int scale_id = to_buffer_id(scale_pos, scales_strides);
69-
const float scale = float(t_scales_and_zeros[scale_id]);
70-
71-
zero_pos.x = kb;
72-
const int zero_id = to_buffer_id(zero_pos, scales_strides);
73-
const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0;
74-
75-
for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
76-
mat1_pos.x = k;
77-
const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
78-
const float mat1_val = float(t_mat1[mat1_id]);
79-
80-
mat2_pos.x = k / 2;
81-
const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
82-
// Bitwise op treats sign bit from int8 as a value bit instead,
83-
// since there is no uint8_t datatype
84-
uint mat2_val = (t_mat2[mat2_id] & 0xFF);
85-
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
68+
#ifdef USING_BUFFER
69+
const uint k_block = (K + group_size - 1) / group_size;
70+
ivec4 mat1_pos = ivec4(0, m, out_pos.z, out_pos.w);
71+
ivec4 mat2_pos = ivec4(0, n, out_pos.z, out_pos.w);
72+
ivec4 scale_pos = ivec4(0, n, 0, out_pos.w);
73+
ivec4 zero_pos = ivec4(0, n, 1, out_pos.w);
74+
75+
for (int kb = 0; kb < k_block; kb++) {
76+
scale_pos.x = kb;
77+
const int scale_id = to_buffer_id(scale_pos, scales_strides);
78+
const float scale = float(t_scales_and_zeros[scale_id]);
79+
80+
zero_pos.x = kb;
81+
const int zero_id = to_buffer_id(zero_pos, scales_strides);
82+
const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0;
83+
84+
for(uint idx = 0; idx < group_size && k < K; idx++, k++) {
85+
mat1_pos.x = k;
86+
const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
87+
const float mat1_val = float(t_mat1[mat1_id]);
88+
89+
mat2_pos.x = k / 2;
90+
const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
91+
// Bitwise op treats sign bit from int8 as a value bit instead,
92+
// since there is no uint8_t datatype
93+
uint mat2_val = (t_mat2[mat2_id] & 0xFF);
94+
mat2_val = (k & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
95+
96+
rc += mat1_val * (scale * float(mat2_val) + zero);
97+
}
98+
}
8699

87-
rc += mat1_val * (scale * float(mat2_val) + zero);
100+
const int out_id = to_buffer_id(out_pos, out_strides);
101+
t_out[out_id] = FLOAT_T(rc);
102+
103+
#else // Using texture
104+
const uint texel_group_size = group_size / FOUR;
105+
const uint k_block = (K + texel_group_size - 1) / texel_group_size;
106+
ivec3 mat1_pos = ivec3(0, m, out_pos.z);
107+
ivec3 mat2_pos = ivec3(0, n, out_pos.z);
108+
ivec3 scale_pos = ivec3(0, n, 0);
109+
ivec3 zero_pos = ivec3(0, n, 1);
110+
111+
for (int kb = 0; kb < k_block; kb++) {
112+
const int texel_kb = kb / FOUR;
113+
const int kb_offset = kb % FOUR;
114+
115+
scale_pos.x = texel_kb;
116+
const VEC4_T scale_texel = load_texel(t_scales_and_zeros, scale_pos);
117+
const float scale = float(scale_texel[kb_offset]);
118+
119+
zero_pos.x = texel_kb;
120+
const VEC4_T zero_texel = load_texel(t_scales_and_zeros, zero_pos);
121+
const float zero = float(zero_texel[kb_offset]) - scale * 8.0;
122+
123+
for(uint idx = 0; idx < texel_group_size && k < K; idx++, k++) {
124+
mat1_pos.x = k;
125+
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
126+
127+
mat2_pos.x = k / 2;
128+
const i8vec4 mat2_tex = i8vec4(load_texel(t_mat2, mat2_pos));
129+
130+
// Every two texels of mat1 correspond to one texel of mat2
131+
// Even mat1 indeces correspond to first half of mat2 texel and
132+
// odd indeces correspond to second half
133+
const int mat2_offset = (k & 1) == 0 ? 0 : 2;
134+
for (int texel_idx = 0; texel_idx < FOUR; texel_idx++){
135+
// Bitwise op treats sign bit from int8 as a value bit instead,
136+
// since there is no uint8_t datatype
137+
uint mat2_val = (mat2_tex[mat2_offset + texel_idx / 2] & 0xFF);
138+
mat2_val = (texel_idx & 1) == 0 ? mat2_val & mask : (mat2_val >> 4);
139+
rc += mat1_tex[texel_idx] * (scale * float(mat2_val) + zero);
140+
}
141+
}
88142
}
89-
}
143+
write_texel(t_out, out_pos.xyz, vec4(rc, 0, 0, 0));
90144

91-
const int out_id = to_buffer_id(out_pos, out_strides);
92-
t_out[out_id] = FLOAT_T(rc);
145+
#endif
93146
}

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,8 @@ q_4w_linear:
1212
DTYPE:
1313
- VALUE: float
1414
- VALUE: half
15+
STORAGE:
16+
- VALUE: buffer
17+
- VALUE: texture3d
1518
shader_variants:
1619
- NAME: q_4w_linear

backends/vulkan/runtime/graph/ops/impl/QuantizedMatMul.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,16 @@ void check_q_matmul_args(
3030
VK_CHECK_COND(mat1_sizes.size() == 2);
3131
VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());
3232

33-
VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out));
33+
VK_CHECK_COND(graph.memory_layout_of(mat1) == utils::kWidthPacked);
34+
VK_CHECK_COND(graph.memory_layout_of(mat2_data) == utils::kWidthPacked);
35+
VK_CHECK_COND(
36+
graph.memory_layout_of(scales_and_zeros) == utils::kWidthPacked);
37+
38+
if (graph.storage_type_of(out) == utils::kBuffer) {
39+
VK_CHECK_COND(graph.memory_layout_of(out) == utils::kWidthPacked);
40+
} else {
41+
VK_CHECK_COND(graph.memory_layout_of(out) == utils::kChannelsPacked);
42+
}
3443

3544
const int mat1_K = utils::val_at(-1, mat1_sizes);
3645
const int mat2_K = utils::val_at(-1, mat2_sizes) * 2;
@@ -95,24 +104,39 @@ void add_q_matmul_node(
95104
const ValueRef group_size,
96105
const ValueRef scales_and_zeros_data,
97106
const ValueRef out) {
98-
ValueRef mat2 =
99-
prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
107+
auto storage_type = graph.storage_type_of(out);
108+
109+
ValueRef mat2;
110+
111+
if (storage_type == utils::kBuffer) {
112+
mat2 = prepack_buffer_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
113+
} else {
114+
mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kWidthPacked);
115+
}
116+
100117
ValueRef scales_and_zeros =
101118
prepack_if_tensor_ref(graph, scales_and_zeros_data, utils::kWidthPacked);
102119

103120
std::string kernel_name = "q_4w_linear";
104121

105122
add_dtype_suffix(kernel_name, graph.dtype_of(out));
123+
add_storage_type_suffix(kernel_name, storage_type);
106124

107125
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
108126

109127
vkapi::ParamsBindList ubos({});
110-
ubos.append(graph.sizes_ubo(out));
111-
ubos.append(graph.strides_ubo(out));
112-
ubos.append(graph.strides_ubo(mat1));
113-
ubos.append(graph.sizes_ubo(mat2));
114-
ubos.append(graph.strides_ubo(mat2));
115-
ubos.append(graph.strides_ubo(scales_and_zeros));
128+
if (storage_type == utils::kBuffer) {
129+
ubos.append(graph.sizes_ubo(out));
130+
ubos.append(graph.strides_ubo(out));
131+
ubos.append(graph.sizes_ubo(mat1));
132+
ubos.append(graph.strides_ubo(mat1));
133+
ubos.append(graph.strides_ubo(mat2));
134+
ubos.append(graph.strides_ubo(scales_and_zeros));
135+
} else {
136+
ubos.append(graph.sizes_ubo(out));
137+
ubos.append(graph.sizes_ubo(mat1));
138+
ubos.append(graph.strides_ubo(scales_and_zeros));
139+
}
116140

117141
auto out_sizes = graph.sizes_of(out);
118142
uint32_t N = utils::val_at(-1, out_sizes);

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,10 @@ TEST(VulkanComputeGraphOpsTest, grid_priors_test) {
27382738
/*data_out_expected = */ {4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12});
27392739
}
27402740

2741-
void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
2741+
void test_int4pack_mm(
2742+
std::vector<uint32_t> MKN,
2743+
uint32_t group_size,
2744+
utils::StorageType storage_type) {
27422745
GraphConfig config;
27432746
ComputeGraph graph(config);
27442747

@@ -2752,8 +2755,7 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
27522755
const std::vector<int64_t> out_size = {M, N};
27532756

27542757
std::vector<float> A_data = create_random_float_buffer(M * K);
2755-
IOValueRef A =
2756-
graph.add_input_tensor(mat1_size, vkapi::kFloat, utils::kBuffer);
2758+
IOValueRef A = graph.add_input_tensor(mat1_size, vkapi::kFloat, storage_type);
27572759
graph.copy_into_staging(A.staging, A_data.data(), A_data.size());
27582760

27592761
// Quantized but un-packed weights
@@ -2764,7 +2766,7 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
27642766
int4mm_pack_weights(mat2_size, B_quant_data.data());
27652767

27662768
IOValueRef B_int4 =
2767-
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, utils::kBuffer);
2769+
graph.add_input_tensor(mat2_q_size, vkapi::kQInt8, storage_type);
27682770
graph.copy_into_staging(
27692771
B_int4.staging, B_int4_data.data(), B_int4_data.size());
27702772

@@ -2773,7 +2775,7 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
27732775
// Random scales and zeroes. Keep scales small to avoid overflow and zeroes in
27742776
// int4 range
27752777
IOValueRef scales_and_zeros =
2776-
graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, utils::kBuffer);
2778+
graph.add_input_tensor({2, N, k_groups}, vkapi::kFloat, storage_type);
27772779
std::vector<float> s_data(graph.numel_of(scales_and_zeros.value));
27782780
const int zeros_stride = s_data.size() / 2;
27792781
for (size_t i = 0; i < zeros_stride; i++) {
@@ -2785,7 +2787,13 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
27852787
scales_and_zeros.staging, s_data.data(), s_data.size());
27862788

27872789
IOValueRef out_int4;
2788-
out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);
2790+
2791+
if (storage_type == utils::kBuffer) {
2792+
out_int4.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);
2793+
} else {
2794+
out_int4.value =
2795+
graph.add_tensor(out_size, vkapi::kFloat, utils::kChannelsPacked);
2796+
}
27892797

27902798
VK_GET_OP_FN("aten._weight_int4pack_mm.default")
27912799
(graph,
@@ -2799,13 +2807,13 @@ void test_int4pack_mm(std::vector<uint32_t> MKN, uint32_t group_size) {
27992807

28002808
// Dequantized matmul for comparison
28012809
IOValueRef B_deq =
2802-
graph.add_input_tensor(mat2_size, vkapi::kFloat, utils::kBuffer);
2810+
graph.add_input_tensor(mat2_size, vkapi::kFloat, storage_type);
28032811
std::vector<float> B_deq_data = int4mm_dequantize_weights(
28042812
mat2_size, B_quant_data.data(), group_size, s_data.data());
28052813
graph.copy_into_staging(B_deq.staging, B_deq_data.data(), B_deq_data.size());
28062814

28072815
IOValueRef out_deq;
2808-
out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, utils::kBuffer);
2816+
out_deq.value = graph.add_tensor(out_size, vkapi::kFloat, storage_type);
28092817

28102818
VK_GET_OP_FN("aten.mm.default")
28112819
(graph, {A.value, B_deq.value, out_deq.value});
@@ -2838,18 +2846,20 @@ TEST(VulkanComputeGraphOpsTest, int4pack_mm_test) {
28382846
GTEST_SKIP();
28392847
}
28402848

2841-
// Vector multiplication, single group per row
2842-
test_int4pack_mm({1, 32, 1}, 32);
2849+
for (auto storage_type : {utils::kBuffer, utils::kTexture3D}) {
2850+
// Vector multiplication, single group per row
2851+
test_int4pack_mm({1, 32, 1}, 32, storage_type);
28432852

2844-
// Vector multiplication, multiple groups per row
2845-
test_int4pack_mm({1, 256, 1}, 64);
2853+
// Vector multiplication, multiple groups per row
2854+
test_int4pack_mm({1, 256, 1}, 64, storage_type);
28462855

2847-
// Square matrices, single group per row
2848-
test_int4pack_mm({32, 32, 32}, 32);
2856+
// Square matrices, single group per row
2857+
test_int4pack_mm({32, 32, 32}, 32, storage_type);
28492858

2850-
// Irregular matrices, single group per row
2851-
test_int4pack_mm({37, 32, 19}, 32);
2859+
// Irregular matrices, single group per row
2860+
test_int4pack_mm({37, 32, 19}, 32, storage_type);
28522861

2853-
// Irregular matrices, multiple groups per row
2854-
test_int4pack_mm({37, 256, 19}, 64);
2862+
// Irregular matrices, multiple groups per row
2863+
test_int4pack_mm({37, 256, 19}, 64, storage_type);
2864+
}
28552865
}

0 commit comments

Comments
 (0)