12
12
13
13
#define PRECISION ${PRECISION}
14
14
15
+ #define FOUR 4
16
+
17
+ #define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
15
18
#define FLOAT_T ${buffer_scalar_type(DTYPE)}
16
19
17
20
${define_active_storage_type(STORAGE)}
@@ -26,12 +29,17 @@ ${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
26
29
${layout_declare_tensor(2 , "r", "t_mat2", "int8", STORAGE)}
27
30
${layout_declare_tensor(3 , "r", "t_scales_and_zeros", DTYPE, STORAGE)}
28
31
29
- ${layout_declare_ubo(4 , "ivec4 ", "out_sizes")}
30
- ${layout_declare_ubo(5 , "ivec4 ", "out_strides")}
31
- ${layout_declare_ubo(6 , "ivec4 ", "mat1_strides")}
32
- ${layout_declare_ubo(7 , "ivec4 ", "mat2_sizes")}
33
- ${layout_declare_ubo(8 , "ivec4 ", "mat2_strides")}
34
- ${layout_declare_ubo(9 , "ivec4 ", "scales_strides")}
32
+ $if STORAGE == "texture3d":
33
+ ${layout_declare_ubo(4 , "ivec4 ", "out_sizes")}
34
+ ${layout_declare_ubo(5 , "ivec4 ", "mat1_sizes")}
35
+ ${layout_declare_ubo(6 , "ivec4 ", "scales_strides")}
36
+ $else :
37
+ ${layout_declare_ubo(4 , "ivec4 ", "out_sizes")}
38
+ ${layout_declare_ubo(5 , "ivec4 ", "out_strides")}
39
+ ${layout_declare_ubo(6 , "ivec4 ", "mat1_sizes")}
40
+ ${layout_declare_ubo(7 , "ivec4 ", "mat1_strides")}
41
+ ${layout_declare_ubo(8 , "ivec4 ", "mat2_strides")}
42
+ ${layout_declare_ubo(9 , "ivec4 ", "scales_strides")}
35
43
36
44
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
37
45
@@ -49,45 +57,90 @@ void main() {
49
57
return ;
50
58
}
51
59
52
- const uint K = mat2_sizes.x * 2 ;
53
- const uint N = mat2_sizes.y;
60
+ const uint K = mat1_sizes.x;
54
61
const uint n = out_pos.x;
55
62
const uint m = out_pos.y;
56
- const uint k_block = (K + group_size - 1 ) / group_size;
57
63
const uint mask = uint (0x0f);
58
- ivec4 mat1_pos = ivec4 (0 , m, out_pos.z, out_pos.w);
59
- ivec4 mat2_pos = ivec4 (0 , n, out_pos.z, out_pos.w);
60
- ivec4 scale_pos = ivec4 (0 , n, 0 , out_pos.w);
61
- ivec4 zero_pos = ivec4 (0 , n, 1 , out_pos.w);
62
64
63
65
float rc = 0.0 ;
64
66
int k = 0 ;
65
67
66
- for (int kb = 0 ; kb < k_block; kb++ ) {
67
- scale_pos.x = kb;
68
- const int scale_id = to_buffer_id(scale_pos, scales_strides);
69
- const float scale = float (t_scales_and_zeros[scale_id]);
70
-
71
- zero_pos.x = kb;
72
- const int zero_id = to_buffer_id(zero_pos, scales_strides);
73
- const float zero = float (t_scales_and_zeros[zero_id]) - scale * 8.0 ;
74
-
75
- for (uint idx = 0 ; idx < group_size && k < K; idx++ , k++ ) {
76
- mat1_pos.x = k;
77
- const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
78
- const float mat1_val = float (t_mat1[mat1_id]);
79
-
80
- mat2_pos.x = k / 2 ;
81
- const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
82
- // Bitwise op treats sign bit from int8 as a value bit instead,
83
- // since there is no uint8_t datatype
84
- uint mat2_val = (t_mat2[mat2_id] & 0xFF);
85
- mat2_val = (k & 1 ) == 0 ? mat2_val & mask : (mat2_val >> 4 );
68
+ #ifdef USING_BUFFER
69
+ const uint k_block = (K + group_size - 1 ) / group_size;
70
+ ivec4 mat1_pos = ivec4 (0 , m, out_pos.z, out_pos.w);
71
+ ivec4 mat2_pos = ivec4 (0 , n, out_pos.z, out_pos.w);
72
+ ivec4 scale_pos = ivec4 (0 , n, 0 , out_pos.w);
73
+ ivec4 zero_pos = ivec4 (0 , n, 1 , out_pos.w);
74
+
75
+ for (int kb = 0 ; kb < k_block; kb++ ) {
76
+ scale_pos.x = kb;
77
+ const int scale_id = to_buffer_id(scale_pos, scales_strides);
78
+ const float scale = float (t_scales_and_zeros[scale_id]);
79
+
80
+ zero_pos.x = kb;
81
+ const int zero_id = to_buffer_id(zero_pos, scales_strides);
82
+ const float zero = float (t_scales_and_zeros[zero_id]) - scale * 8.0 ;
83
+
84
+ for (uint idx = 0 ; idx < group_size && k < K; idx++ , k++ ) {
85
+ mat1_pos.x = k;
86
+ const int mat1_id = to_buffer_id(mat1_pos, mat1_strides);
87
+ const float mat1_val = float (t_mat1[mat1_id]);
88
+
89
+ mat2_pos.x = k / 2 ;
90
+ const int mat2_id = to_buffer_id(mat2_pos, mat2_strides);
91
+ // Bitwise op treats sign bit from int8 as a value bit instead,
92
+ // since there is no uint8_t datatype
93
+ uint mat2_val = (t_mat2[mat2_id] & 0xFF);
94
+ mat2_val = (k & 1 ) == 0 ? mat2_val & mask : (mat2_val >> 4 );
95
+
96
+ rc += mat1_val * (scale * float (mat2_val) + zero);
97
+ }
98
+ }
86
99
87
- rc += mat1_val * (scale * float (mat2_val) + zero);
100
+ const int out_id = to_buffer_id(out_pos, out_strides);
101
+ t_out[out_id] = FLOAT_T(rc);
102
+
103
+ #else // Using texture
104
+ const uint texel_group_size = group_size / FOUR;
105
+ const uint k_block = (K + texel_group_size - 1 ) / texel_group_size;
106
+ ivec3 mat1_pos = ivec3 (0 , m, out_pos.z);
107
+ ivec3 mat2_pos = ivec3 (0 , n, out_pos.z);
108
+ ivec3 scale_pos = ivec3 (0 , n, 0 );
109
+ ivec3 zero_pos = ivec3 (0 , n, 1 );
110
+
111
+ for (int kb = 0 ; kb < k_block; kb++ ) {
112
+ const int texel_kb = kb / FOUR;
113
+ const int kb_offset = kb % FOUR;
114
+
115
+ scale_pos.x = texel_kb;
116
+ const VEC4_T scale_texel = load_texel(t_scales_and_zeros, scale_pos);
117
+ const float scale = float (scale_texel[kb_offset]);
118
+
119
+ zero_pos.x = texel_kb;
120
+ const VEC4_T zero_texel = load_texel(t_scales_and_zeros, zero_pos);
121
+ const float zero = float (zero_texel[kb_offset]) - scale * 8.0 ;
122
+
123
+ for (uint idx = 0 ; idx < texel_group_size && k < K; idx++ , k++ ) {
124
+ mat1_pos.x = k;
125
+ const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
126
+
127
+ mat2_pos.x = k / 2 ;
128
+ const i8vec4 mat2_tex = i8vec4(load_texel(t_mat2, mat2_pos));
129
+
130
+ // Every two texels of mat1 correspond to one texel of mat2
131
+ // Even mat1 indices correspond to first half of mat2 texel and
132
+ // odd indices correspond to second half
133
+ const int mat2_offset = (k & 1 ) == 0 ? 0 : 2 ;
134
+ for (int texel_idx = 0 ; texel_idx < FOUR; texel_idx++ ){
135
+ // Bitwise op treats sign bit from int8 as a value bit instead,
136
+ // since there is no uint8_t datatype
137
+ uint mat2_val = (mat2_tex[mat2_offset + texel_idx / 2 ] & 0xFF);
138
+ mat2_val = (texel_idx & 1 ) == 0 ? mat2_val & mask : (mat2_val >> 4 );
139
+ rc += mat1_tex[texel_idx] * (scale * float (mat2_val) + zero);
140
+ }
141
+ }
88
142
}
89
- }
143
+ write_texel(t_out, out_pos.xyz, vec4 (rc, 0 , 0 , 0 ));
90
144
91
- const int out_id = to_buffer_id(out_pos, out_strides);
92
- t_out[out_id] = FLOAT_T(rc);
145
+ #endif
93
146
}
0 commit comments