@@ -4372,6 +4372,143 @@ kernel void kernel_concat(
4372
4372
}
4373
4373
}
4374
4374
4375
+ template <typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376
+ kernel void kernel_cpy_q_f32 (
4377
+ constant ggml_metal_kargs_cpy & args,
4378
+ device const char * cx [[ buffer(1 ) ]],
4379
+ device char * cdst [[ buffer(2 ) ]],
4380
+ uint tid [[ thread_position_in_grid ]]
4381
+ )
4382
+ {
4383
+ // Compute the global index multiplied by QK, matching:
4384
+ // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385
+ const int i = int (tid) * qqk;
4386
+
4387
+ // Bounds check
4388
+ if (i >= args.ne ) {
4389
+ return ;
4390
+ }
4391
+
4392
+ const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4393
+ const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4394
+ const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4395
+ const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4396
+ const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4397
+
4398
+ const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4399
+ const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4400
+ const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4401
+ const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4402
+ const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4403
+
4404
+ device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405
+ device float * dst = (device float *)(cdst + dst_offset);
4406
+
4407
+ dequantize_func (src_block, dst);
4408
+ }
4409
+
4410
+ void dequant_q4_0_f (device const block_q4_0 * src_block, device float * dst) {
4411
+ float d = float (src_block->d );
4412
+ const float shift = 8 .0f ;
4413
+
4414
+ // Unpack 2 x 4-bit values per byte.
4415
+ for (int j = 0 ; j < QK4_0/2 ; j++) {
4416
+ uint8_t q = src_block->qs [j];
4417
+ uint8_t q0 = q & 0x0F ;
4418
+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4419
+ dst[j] = (float (q0) - shift) * d;
4420
+ dst[j + QK4_0/2 ] = (float (q1) - shift) * d;
4421
+ }
4422
+ }
4423
+
4424
+ void dequant_q4_1_f (device const block_q4_1 * src_block, device float * dst) {
4425
+ float d = float (src_block->d );
4426
+ float vmin = float (src_block->m );
4427
+
4428
+ for (int j = 0 ; j < QK4_1/2 ; j++) {
4429
+ uint8_t q = src_block->qs [j];
4430
+ uint8_t q0 = q & 0x0F ;
4431
+ uint8_t q1 = (q >> 4 ) & 0x0F ;
4432
+ dst[j] = vmin + d * float (q0);
4433
+ dst[j + QK4_1/2 ] = vmin + d * float (q1);
4434
+ }
4435
+ }
4436
+
4437
+ void dequant_q5_0_f (device const block_q5_0 * src_block, device float * dst) {
4438
+ float d = float (src_block->d );
4439
+ const float shift = 16 .f ;
4440
+
4441
+ // Combine the four qh bytes into a 32-bit value.
4442
+ uint32_t qhVal = 0
4443
+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4444
+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4445
+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4446
+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4447
+
4448
+ // First half
4449
+ for (int j = 0 ; j < QK5_0/2 ; j++) {
4450
+ uint8_t q = src_block->qs [j];
4451
+ uint8_t lowNib = q & 0x0F ;
4452
+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4453
+ uint8_t qVal = (highBit << 4 ) | lowNib;
4454
+ dst[j] = (float (qVal) - shift) * d;
4455
+ }
4456
+ // Second half
4457
+ for (int j = QK5_0/2 ; j < QK5_0; j++) {
4458
+ int k = j - QK5_0/2 ;
4459
+ uint8_t q = src_block->qs [k];
4460
+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4461
+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4462
+ uint8_t qVal = (highBit << 4 ) | hiNib;
4463
+ dst[j] = (float (qVal) - shift) * d;
4464
+ }
4465
+ }
4466
+
4467
+ void dequant_q5_1_f (device const block_q5_1 * src_block, device float * dst) {
4468
+ float d = float (src_block->d );
4469
+ float vmin = float (src_block->m );
4470
+
4471
+ uint32_t qhVal = 0
4472
+ | ((uint32_t ) src_block->qh [0 ] << 0 )
4473
+ | ((uint32_t ) src_block->qh [1 ] << 8 )
4474
+ | ((uint32_t ) src_block->qh [2 ] << 16 )
4475
+ | ((uint32_t ) src_block->qh [3 ] << 24 );
4476
+
4477
+ // First half
4478
+ for (int j = 0 ; j < QK5_1/2 ; j++) {
4479
+ uint8_t q = src_block->qs [j];
4480
+ uint8_t lowNib = q & 0x0F ;
4481
+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4482
+ uint8_t qVal = (highBit << 4 ) | lowNib;
4483
+ dst[j] = vmin + d * float (qVal);
4484
+ }
4485
+ // Second half
4486
+ for (int j = QK5_1/2 ; j < QK5_1; j++) {
4487
+ int k = j - QK5_1/2 ;
4488
+ uint8_t q = src_block->qs [k];
4489
+ uint8_t hiNib = (q >> 4 ) & 0x0F ;
4490
+ uint8_t highBit = (qhVal >> j) & 0x1 ;
4491
+ uint8_t qVal = (highBit << 4 ) | hiNib;
4492
+ dst[j] = vmin + d * float (qVal);
4493
+ }
4494
+ }
4495
+
4496
+ void dequant_q8_0_f (device const block_q8_0 * src_block, device float * dst) {
4497
+ const float d = (float )src_block->d ;
4498
+
4499
+ for (int j = 0 ; j < QK8_0; j++) {
4500
+ dst[j] = src_block->qs [j] * d;
4501
+ }
4502
+ }
4503
+
4504
+ typedef decltype (kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4505
+
4506
+ template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4507
+ template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4508
+ template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4509
+ template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4510
+ template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4511
+
4375
4512
template <typename args_t >
4376
4513
void kernel_mul_mv_q2_K_f32_impl (
4377
4514
args_t args,
0 commit comments