@@ -7,8 +7,8 @@ using namespace metal;
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    half d;              // delta
-    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+    half    d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
 
 static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
@@ -38,8 +38,8 @@ kernel void kernel_add(
         device const float * src0,
         device const float * src1,
         device       float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] + src1[gid];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig];
 }
 
 // assumption: src1 is a row
@@ -49,15 +49,15 @@ kernel void kernel_mul(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = src0[gid] * src1[gid % ne00];
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig % ne00];
 }
 
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
-        uint gid[[thread_position_in_grid]]) {
-    dst[gid] = max(0.0f, src0[gid]);
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
 // TODO: broken
@@ -85,8 +85,8 @@ kernel void kernel_get_rows_q4_0(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb1,
-        uint gid[[thread_position_in_grid]]) {
-    const int i = gid;
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
     dequantize_row_q4_0(
@@ -100,8 +100,8 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        uint gid[[thread_position_in_grid]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + gid*nb01);
+        uint tpig[[thread_position_in_grid]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tpig*nb01);
 
     float sum = 0.0f;
     for (int i00 = 0; i00 < ne00; i00++) {
@@ -111,8 +111,84 @@ kernel void kernel_rms_norm(
     const float mean = sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + gid*ne00;
+    device float * y = dst + tpig*ne00;
     for (int i00 = 0; i00 < ne00; i00++) {
         y[i00] = x[i00] * scale;
     }
 }
+
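+// matrix multiplication with a q4_0-quantized src0 and a float src1.
+// each threadgroup computes one element of dst: the dot product of row r0
+// of the quantized matrix src0 with row r1 of src1, where (r0, r1) is the
+// threadgroup position in the grid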
+kernel void kernel_mul_mat_q4_0(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2  tpig[[thread_position_in_grid]],
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    const int qk = QK4_0;
+    const int nb = ne00/qk;
+
+    device const block_q4_0 * x = (device const block_q4_0 *) (src0) + r0*nb;
+    device const float      * y = (device const float      *) (src1) + r1*ne10;
+
+    threadgroup float sum[32]; // TODO: should be equal to threadgroup size
+    sum[tpitg.x] = 0.0f;
+
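+    // each thread accumulates a partial sum over a strided subset of the row's blocks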
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+        device const uint4  * x0p = (device const uint4  *) (x + i);
+        device const float4 * y0p = (device const float4 *) (y + i*qk);
+
+        const uint4 x0 = *x0p;
+
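+        // separate the low and high nibbles of the packed 4-bit quants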
+        const uint4 x0l = x0 & uint4(0x0F0F0F0F);
+        const uint4 x0h = x0 >> 4;
+
+        const int4 x0ls = as_type<int4>(x0l) - int4(8);
+        const int4 x0hs = as_type<int4>(x0h) - int4(8);
+
+        thread const uchar * x0lsb = (thread const uchar *) &x0ls;
+        thread const uchar * x0hsb = (thread const uchar *) &x0hs;
+
+        const float4 y00 = *(y0p + 0);
+        const float4 y01 = *(y0p + 1);
+        const float4 y02 = *(y0p + 2);
+        const float4 y03 = *(y0p + 3);
+        const float4 y04 = *(y0p + 4);
+        const float4 y05 = *(y0p + 5);
+        const float4 y06 = *(y0p + 6);
+        const float4 y07 = *(y0p + 7);
+
+        const float d = (x + i)->d;
+
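+        // accumulate the partial dot product for this block, scaled by the block's delta d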
+        sum[tpitg.x] += (
+                x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] +
+                x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] +
+                x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] +
+                x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] +
+                x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] +
+                x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] +
+                x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] +
+                x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3]
+            ) * d;
+    }
+
+    // accumulate the sum from all threads in the threadgroup
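+    // binary-tree reduction in threadgroup memory; sum[0] ends up with the full dot product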
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = tptg.x/2; i > 0; i /= 2) {
+        if (tpitg.x < i) {
+            sum[tpitg.x] += sum[tpitg.x + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    dst[r1*ne0 + r0] = sum[0];
+}