@@ -2459,6 +2459,12 @@ typedef struct {
2459
2459
} block_iq2_xs;
2460
2460
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
2461
2461
2462
// Storage layout of one IQ3_XXS super-block (QK_K weights).
// The kernels in this file read `qs` in two parts:
//   - bytes [0, QK_K/4): one grid-index byte per group of 4 weights
//     (8 index bytes per sub-block of 32 weights);
//   - bytes [QK_K/4, 3*QK_K/8): one 32-bit word per sub-block of 32 weights, holding
//     four 7-bit sign indices (bits 0..27) and a 4-bit sub-block scale (bits 28..31).
typedef struct {
    half    d;              // super-block scale, applied to every weight in the block
    uint8_t qs[3*QK_K/8];   // packed grid indices followed by packed signs/scales
} block_iq3_xxs;
// 2 + 96 = 98 bytes / block for QK_K = 256, so 3.0625 bpw
2467
+
2462
2468
// ====================================== dot products =========================
2463
2469
2464
2470
void kernel_mul_mv_q2_K_f32_impl (
@@ -3681,6 +3687,42 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
3681
3687
0x2b2b2b2b082b2b08 , 0x2b2b2b2b082b2b2b , 0x2b2b2b2b2b190819 , 0x2b2b2b2b2b2b2b2b ,
3682
3688
};
3683
3689
3690
// Lookup table used by the IQ3_XXS kernels and dequantizer in this file.
// It has 256 entries; each uint32_t entry is reinterpreted by its readers as 4
// consecutive uint8_t magnitudes (see the `grid1[j]` / `grid2[j]` reads, j = 0..3),
// and is selected by a single index byte taken from block_iq3_xxs::qs.
// NOTE(review): every byte below appears to be one of 0x04, 0x0c, 0x14, ..., 0x3c
// (8 distinct levels). The values look machine-generated - do not edit by hand.
+ constexpr constant static uint32_t iq3xxs_grid[256 ] = {
3691
+ 0x04040404 , 0x04040414 , 0x04040424 , 0x04040c0c , 0x04040c1c , 0x04040c3c , 0x04041404 , 0x04041414 ,
3692
+ 0x04041c0c , 0x04042414 , 0x04043c1c , 0x04043c2c , 0x040c040c , 0x040c041c , 0x040c0c04 , 0x040c0c14 ,
3693
+ 0x040c140c , 0x040c142c , 0x040c1c04 , 0x040c1c14 , 0x040c240c , 0x040c2c24 , 0x040c3c04 , 0x04140404 ,
3694
+ 0x04140414 , 0x04140424 , 0x04140c0c , 0x04141404 , 0x04141414 , 0x04141c0c , 0x04141c1c , 0x04141c3c ,
3695
+ 0x04142c0c , 0x04142c3c , 0x04143c2c , 0x041c040c , 0x041c043c , 0x041c0c04 , 0x041c0c14 , 0x041c142c ,
3696
+ 0x041c3c04 , 0x04240c1c , 0x04241c3c , 0x04242424 , 0x04242c3c , 0x04243c1c , 0x04243c2c , 0x042c040c ,
3697
+ 0x042c043c , 0x042c1c14 , 0x042c2c14 , 0x04341c2c , 0x04343424 , 0x043c0c04 , 0x043c0c24 , 0x043c0c34 ,
3698
+ 0x043c241c , 0x043c340c , 0x0c04040c , 0x0c04041c , 0x0c040c04 , 0x0c040c14 , 0x0c04140c , 0x0c04141c ,
3699
+ 0x0c041c04 , 0x0c041c14 , 0x0c041c24 , 0x0c04243c , 0x0c042c04 , 0x0c0c0404 , 0x0c0c0414 , 0x0c0c0c0c ,
3700
+ 0x0c0c1404 , 0x0c0c1414 , 0x0c14040c , 0x0c14041c , 0x0c140c04 , 0x0c140c14 , 0x0c14140c , 0x0c141c04 ,
3701
+ 0x0c143c14 , 0x0c1c0404 , 0x0c1c0414 , 0x0c1c1404 , 0x0c1c1c0c , 0x0c1c2434 , 0x0c1c3434 , 0x0c24040c ,
3702
+ 0x0c24042c , 0x0c242c04 , 0x0c2c1404 , 0x0c2c1424 , 0x0c2c2434 , 0x0c2c3c0c , 0x0c34042c , 0x0c3c1414 ,
3703
+ 0x0c3c2404 , 0x14040404 , 0x14040414 , 0x14040c0c , 0x14040c1c , 0x14041404 , 0x14041414 , 0x14041434 ,
3704
+ 0x14041c0c , 0x14042414 , 0x140c040c , 0x140c041c , 0x140c042c , 0x140c0c04 , 0x140c0c14 , 0x140c140c ,
3705
+ 0x140c1c04 , 0x140c341c , 0x140c343c , 0x140c3c04 , 0x14140404 , 0x14140414 , 0x14140c0c , 0x14140c3c ,
3706
+ 0x14141404 , 0x14141414 , 0x14141c3c , 0x14142404 , 0x14142c2c , 0x141c040c , 0x141c0c04 , 0x141c0c24 ,
3707
+ 0x141c3c04 , 0x141c3c24 , 0x14241c2c , 0x14242c1c , 0x142c041c , 0x142c143c , 0x142c240c , 0x142c3c24 ,
3708
+ 0x143c040c , 0x143c041c , 0x143c0c34 , 0x143c242c , 0x1c04040c , 0x1c040c04 , 0x1c040c14 , 0x1c04140c ,
3709
+ 0x1c04141c , 0x1c042c04 , 0x1c04342c , 0x1c043c14 , 0x1c0c0404 , 0x1c0c0414 , 0x1c0c1404 , 0x1c0c1c0c ,
3710
+ 0x1c0c2424 , 0x1c0c2434 , 0x1c14040c , 0x1c14041c , 0x1c140c04 , 0x1c14142c , 0x1c142c14 , 0x1c143c14 ,
3711
+ 0x1c1c0c0c , 0x1c1c1c1c , 0x1c241c04 , 0x1c24243c , 0x1c243c14 , 0x1c2c0404 , 0x1c2c0434 , 0x1c2c1414 ,
3712
+ 0x1c2c2c2c , 0x1c340c24 , 0x1c341c34 , 0x1c34341c , 0x1c3c1c1c , 0x1c3c3404 , 0x24040424 , 0x24040c3c ,
3713
+ 0x24041c2c , 0x24041c3c , 0x24042c1c , 0x24042c3c , 0x240c3c24 , 0x24141404 , 0x24141c3c , 0x24142404 ,
3714
+ 0x24143404 , 0x24143434 , 0x241c043c , 0x241c242c , 0x24240424 , 0x24242c0c , 0x24243424 , 0x242c142c ,
3715
+ 0x242c241c , 0x242c3c04 , 0x243c042c , 0x243c0c04 , 0x243c0c14 , 0x243c1c04 , 0x2c040c14 , 0x2c04240c ,
3716
+ 0x2c043c04 , 0x2c0c0404 , 0x2c0c0434 , 0x2c0c1434 , 0x2c0c2c2c , 0x2c140c24 , 0x2c141c14 , 0x2c143c14 ,
3717
+ 0x2c1c0414 , 0x2c1c2c1c , 0x2c240c04 , 0x2c24141c , 0x2c24143c , 0x2c243c14 , 0x2c2c0414 , 0x2c2c1c0c ,
3718
+ 0x2c342c04 , 0x2c3c1424 , 0x2c3c2414 , 0x34041424 , 0x34042424 , 0x34042434 , 0x34043424 , 0x340c140c ,
3719
+ 0x340c340c , 0x34140c3c , 0x34143424 , 0x341c1c04 , 0x341c1c34 , 0x34242424 , 0x342c042c , 0x342c2c14 ,
3720
+ 0x34341c1c , 0x343c041c , 0x343c140c , 0x3c04041c , 0x3c04042c , 0x3c04043c , 0x3c040c04 , 0x3c041c14 ,
3721
+ 0x3c042c14 , 0x3c0c1434 , 0x3c0c2404 , 0x3c140c14 , 0x3c14242c , 0x3c142c14 , 0x3c1c0404 , 0x3c1c0c2c ,
3722
+ 0x3c1c1c1c , 0x3c1c3404 , 0x3c24140c , 0x3c24240c , 0x3c2c0404 , 0x3c2c0414 , 0x3c2c1424 , 0x3c341c04 ,
3723
+ };
3724
+
3725
+
3684
3726
constexpr constant static uint8_t ksigns_iq2xs[128 ] = {
3685
3727
0 , 129 , 130 , 3 , 132 , 5 , 6 , 135 , 136 , 9 , 10 , 139 , 12 , 141 , 142 , 15 ,
3686
3728
144 , 17 , 18 , 147 , 20 , 149 , 150 , 23 , 24 , 153 , 154 , 27 , 156 , 29 , 30 , 159 ,
@@ -3970,6 +4012,143 @@ kernel void kernel_mul_mv_iq2_xs_f32(
3970
4012
kernel_mul_mv_iq2_xs_f32_impl (src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
4013
}
3972
4014
4015
// Matrix-vector product of an IQ3_XXS-quantized matrix (src0) with a float vector (src1).
//
// Each simdgroup accumulates N_DST consecutive output rows starting at
// first_row = (tgpig.x*N_SIMDGROUP + sgitg)*N_DST; tgpig.y selects the src1 row (r1)
// and tgpig.z the batch index (im). Inside a simdgroup, lane `tiisg` handles the
// 32-weight sub-blocks tiisg, tiisg+32, ... of every row; the per-lane partial sums
// are combined with simd_sum() and lane 0 writes the result to dst.
//
// shared_values is threadgroup scratch that receives a copy of iq3xxs_grid
// (256 x uint32_t) followed by a copy of ksigns_iq2xs (128 x uint8_t).
// NOTE(review): the copy below only fills both tables completely when the
// threadgroup has 64 threads (2 simdgroups of 32) - confirm against the host dispatch.
// NOTE(review): there is no check that first_row + row < ne01 - the host is
// presumably responsible for padding/dispatch; verify.
void kernel_mul_mv_iq3_xxs_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;   // super-blocks per row
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // r2/r3 are the src1/src0 broadcast ratios along dims 2 and 3
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST] = {0.f};

    const int nb32 = nb * (QK_K / 32);   // 32-weight sub-blocks per row

    // Stage the lookup tables in threadgroup memory: every thread copies 4 grid
    // entries and 2 sign entries, then all threads synchronize before reading.
    threadgroup uint32_t * values       = (threadgroup uint32_t *)shared_values;
    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t  *)(values + 256);
    {
        int nval = 4;
        int pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
        nval = 2;
        pos  = (32*sgitg + tiisg)*nval;
        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

#if QK_K == 256
    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        // cache the 32 activations that pair with this sub-block; reused for all N_DST rows
        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // super-block index within the row
        const int ib  = ib32 % (QK_K / 32);   // sub-block index within the super-block

        device const block_iq3_xxs * xr = x + ibl;
        device const uint8_t  * q3  = xr->qs + 8 * ib;
        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
        device const half     * dh  = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float db = dh[0];
            // Assemble the 32-bit sign/scale word from two 16-bit halves. The casts keep
            // the shift in unsigned arithmetic: a bare uint16_t would be promoted to
            // signed int and shifted into the sign bit.
            const uint32_t aux32 = (uint32_t) gas[0] | ((uint32_t) gas[1] << 16);
            // top 4 bits: sub-block scale; the final 0.5 factor is applied when writing dst
            const float d = db * (0.5f + (aux32 >> 28));

            float2 sum = {0};
            for (int l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
                // 7-bit index l selects the sign byte for weights 8*l .. 8*l+7
                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
                for (int j = 0; j < 4; ++j) {
                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * (sum[0] + sum[1]);

            // advance all three views by one row (nb blocks), in units of their element size
            dh  += nb*sizeof(block_iq3_xxs)/2;
            q3  += nb*sizeof(block_iq3_xxs);
            gas += nb*sizeof(block_iq3_xxs)/2;
        }

        y4 += 32 * 32;
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays zero in that configuration
#endif

    for (int row = 0; row < N_DST; ++row) {
        const float all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
        }
    }
}
4121
+
4122
// Kernel entry point for the IQ3_XXS matrix-vector product. It forwards to
// kernel_mul_mv_iq3_xxs_f32_impl; the stride arguments (nb00..nb12) and ne11 are
// accepted but not passed on.
// The host_name string must match the name the host code looks up exactly, so it
// carries no leading or trailing whitespace.
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
4150
+
4151
+
3973
4152
// ============================= templates and their specializations =============================
3974
4153
3975
4154
// NOTE: this is not dequantizing - we are simply fitting the template
@@ -4287,6 +4466,33 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
4287
4466
}
4288
4467
}
4289
4468
4469
// Dequantize 16 weights of an IQ3_XXS super-block into a 4x4 register tile.
//   xb  - the super-block to read
//   il  - tile index; 0...15 for QK_K = 256, where il/2 selects the sub-block of 32
//         weights and il%2 selects its first or second half of 16
//   reg - output tile; reg[c][i] receives weight 4*c + i of the selected half
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
    device const uint8_t  * q3  = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
    // Assemble the 32-bit sign/scale word. The casts keep the shift unsigned: a bare
    // uint16_t would be promoted to signed int and shifted into the sign bit.
    const uint32_t aux32 = (uint32_t) gas[0] | ((uint32_t) gas[1] << 16);
    // top 4 bits of aux32 hold the sub-block scale
    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;

    // first 8 weights of this half: two grid entries and one 7-bit sign index
    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }

    // next 8 weights: the following two grid entries and the next 7-bit sign index
    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
    for (int i = 0; i < 4; ++i) {
        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
    }
}
4495
+
4290
4496
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread float4x4 &)>
4291
4497
kernel void kernel_get_rows (
4292
4498
device const void * src0,
@@ -4828,6 +5034,7 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
4828
5034
// get_rows instantiations; each host_name must match the host-side lookup string exactly (no padding spaces)
template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_t kernel_get_rows<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4831
5038
4832
5039
//
4833
5040
// matrix-matrix multiplication
@@ -4866,6 +5073,7 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4866
5073
// mul_mm instantiations; each host_name must match the host-side lookup string exactly (no padding spaces)
template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4869
5077
4870
5078
//
4871
5079
// indirect matrix-matrix multiplication
@@ -4916,6 +5124,7 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
4916
5124
// mul_mm_id instantiations; each host_name must match the host-side lookup string exactly (no padding spaces)
template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4919
5128
4920
5129
//
4921
5130
// matrix-vector multiplication
@@ -5818,3 +6027,68 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
5818
6027
tiisg,
5819
6028
sgitg);
5820
6029
}
6030
+
6031
// Indirect IQ3_XXS matrix-vector product: selects one of up to eight src0 matrices
// (src00..src07) using an int32 id read from `ids`, then forwards to
// kernel_mul_mv_iq3_xxs_f32_impl.
//   bid = tgpig.z / (ne12*ne13) picks the row of `ids` (stride nbi1 bytes), the row of
//         src1 (stride nb11 bytes) and the row of dst (stride ne0 floats);
//   idx picks the entry within that `ids` row.
// tgpig.z is reduced modulo ne12*ne13 before being handed to the implementation.
// NOTE(review): `id` is used to index src0[8] without a range check - presumably the
// host guarantees 0 <= id < 8; verify against the caller.
// The host_name string must match the name the host code looks up exactly, so it
// carries no leading or trailing whitespace.
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
kernel void kernel_mul_mv_id_iq3_xxs_f32(
        device const   char * ids,
        device const   char * src1,
        device        float * dst,
        constant   uint64_t & nbi1,
        constant    int64_t & ne00,
        constant    int64_t & ne01,
        constant    int64_t & ne02,
        constant   uint64_t & nb00,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne10,
        constant    int64_t & ne11,
        constant    int64_t & ne12,
        constant    int64_t & ne13,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
        constant    int64_t & ne1,
        constant   uint64_t & nb1,
        constant       uint & r2,
        constant       uint & r3,
        constant        int & idx,
        device const   char * src00,
        device const   char * src01,
        device const   char * src02,
        device const   char * src03,
        device const   char * src04,
        device const   char * src05,
        device const   char * src06,
        device const   char * src07,
        threadgroup int8_t  * shared_values [[threadgroup(0)]],
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  tiisg[[thread_index_in_simdgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    const int64_t bid = tgpig.z/(ne12*ne13);

    tgpig.z = tgpig.z%(ne12*ne13);

    // read-only lookup: keep the const qualifier of `ids` through the cast
    const int32_t id = ((device const int32_t *) (ids + bid*nbi1))[idx];

    kernel_mul_mv_iq3_xxs_f32_impl(
        src0[id],
        (device const float *) (src1 + bid*nb11),
        dst + bid*ne0,
        ne00,
        ne01,
        ne02,
        ne10,
        ne12,
        ne0,
        ne1,
        r2,
        r3,
        shared_values,
        tgpig,
        tiisg,
        sgitg);
}
0 commit comments