5
5
6
6
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
7
7
8
+ shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][16];
9
+
8
10
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
9
11
uint a_offset, b_offset, d_offset;
10
12
get_offsets(a_offset, b_offset, d_offset);
@@ -16,6 +18,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
16
18
const uint tid = gl_LocalInvocationID.x;
17
19
const uint itid = tid%16; // 0...15
18
20
const uint ix = tid/16;
21
+ const uint itid8 = itid%8;
19
22
20
23
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
21
24
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
@@ -42,18 +45,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
42
45
const FLOAT_TYPE dall = d.x;
43
46
const FLOAT_TYPE dmin = d.y;
44
47
45
- uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4];
46
- uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
47
-
48
- uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
49
- uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
50
- uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
51
- uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
52
-
53
- uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
54
- uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
55
- uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
56
- uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
48
+ sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
49
+ sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
50
+ barrier();
57
51
58
52
uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2];
59
53
uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
@@ -73,22 +67,22 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
73
67
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
74
68
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
75
69
[[unroll]] for (int l = 0; l < 2; ++l) {
76
- sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0 ) & 3),
77
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0 ) & 3),
78
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
79
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
80
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
81
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
82
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
83
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
84
- sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]) ,
85
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]) ,
86
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]) ,
87
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]) ,
88
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]) ,
89
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]) ,
90
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]) ,
91
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]) , sum2))))))));
70
+ sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3),
71
+ fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3),
72
+ fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3),
73
+ fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3),
74
+ fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3),
75
+ fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3),
76
+ fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3),
77
+ fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
78
+ sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8] ,
79
+ fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9] ,
80
+ fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10] ,
81
+ fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][11] ,
82
+ fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][12] ,
83
+ fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][13] ,
84
+ fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][14] ,
85
+ fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][15] , sum2))))))));
92
86
}
93
87
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
94
88
}
0 commit comments