@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
4341
4341
}
4342
4342
}
4343
4343
4344
+ template <typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
4345
+ kernel void kernel_cpy_q_f32 (
4346
+ constant ggml_metal_kargs_cpy & args,
4347
+ device const char * src0,
4348
+ device char * dst,
4349
+ uint3 tgpig[[threadgroup_position_in_grid]],
4350
+ ushort3 tpitg[[thread_position_in_threadgroup]],
4351
+ ushort3 ntg[[threads_per_threadgroup]]) {
4352
+ const int i03 = tgpig[2 ];
4353
+ const int i02 = tgpig[1 ];
4354
+ const int i01 = tgpig[0 ];
4355
+
4356
+ const int64_t n = i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 ;
4357
+
4358
+ const int64_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
4359
+ const int64_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
4360
+ const int64_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
4361
+ const int64_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
4362
+
4363
+ device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 );
4364
+ device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
4365
+
4366
+ for (int64_t i00 = tpitg.x ; i00 < args.ne00 /16 ; i00 += ntg.x ) {
4367
+ T4x4 temp;
4368
+ dequantize_func (src_data + i00/nl, i00%nl, temp);
4369
+ dst_data[i00] = temp;
4370
+ }
4371
+ }
4372
+
4373
+ typedef decltype (kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>) cpy_q_f_t;
4374
+
4375
+ template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2 , dequantize_q4_0>;
4376
+ template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2 , dequantize_q4_1>;
4377
+ template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2 , dequantize_q5_0>;
4378
+ template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2 , dequantize_q5_1>;
4379
+ template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2 , dequantize_q8_0>;
4380
+
4381
+ template [[host_name(" kernel_cpy_q4_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2 , dequantize_q4_0>;
4382
+ template [[host_name(" kernel_cpy_q4_1_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2 , dequantize_q4_1>;
4383
+ template [[host_name(" kernel_cpy_q5_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2 , dequantize_q5_0>;
4384
+ template [[host_name(" kernel_cpy_q5_1_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2 , dequantize_q5_1>;
4385
+ template [[host_name(" kernel_cpy_q8_0_f16" )]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2 , dequantize_q8_0>;
4386
+
4344
4387
kernel void kernel_concat (
4345
4388
constant ggml_metal_kargs_concat & args,
4346
4389
device const char * src0,
@@ -4372,150 +4415,6 @@ kernel void kernel_concat(
4372
4415
}
4373
4416
}
4374
4417
4375
- template <typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376
- kernel void kernel_cpy_q_f32 (
4377
- constant ggml_metal_kargs_cpy & args,
4378
- device const char * cx [[ buffer(1 ) ]],
4379
- device char * cdst [[ buffer(2 ) ]],
4380
- uint tid [[ thread_position_in_grid ]]
4381
- )
4382
- {
4383
- // Compute the global index multiplied by QK, matching:
4384
- // i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385
- const int i = int (tid) * qqk;
4386
-
4387
- // Bounds check
4388
- if (i >= args.ne ) {
4389
- return ;
4390
- }
4391
-
4392
- const int i03 = i/(args.ne00 * args.ne01 * args.ne02 );
4393
- const int i02 = (i - i03*args.ne00 *args.ne01 *args.ne02 )/ (args.ne00 *args.ne01 );
4394
- const int i01 = (i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 ) / args.ne00 ;
4395
- const int i00 = i - i03*args.ne00 *args.ne01 *args.ne02 - i02*args.ne01 *args.ne00 - i01*args.ne00 ;
4396
- const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03 ;
4397
-
4398
- const int i13 = i/(args.ne0 * args.ne1 * args.ne2 );
4399
- const int i12 = (i - i13*args.ne0 *args.ne1 *args.ne2 ) / (args.ne0 *args.ne1 );
4400
- const int i11 = (i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 ) / args.ne0 ;
4401
- const int i10 = i - i13*args.ne0 *args.ne1 *args.ne2 - i12*args.ne0 *args.ne1 - i11*args.ne0 ;
4402
- const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3 ;
4403
-
4404
- device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405
- device float * dst = (device float *)(cdst + dst_offset);
4406
-
4407
- dequantize_func (src_block, dst);
4408
- }
4409
-
4410
- void dequant_q4_0_f (device const block_q4_0 * src_block, device float * dst) {
4411
- float d = float (src_block->d );
4412
- const float shift = 8 .0f ;
4413
-
4414
- // Unpack 2 x 4-bit values per byte.
4415
- #pragma unroll(16)
4416
- for (int j = 0 ; j < QK4_0/2 ; j++) {
4417
- uint8_t q = src_block->qs [j];
4418
- uint8_t q0 = q & 0x0F ;
4419
- uint8_t q1 = (q >> 4 ) & 0x0F ;
4420
- dst[j] = (float (q0) - shift) * d;
4421
- dst[j + QK4_0/2 ] = (float (q1) - shift) * d;
4422
- }
4423
- }
4424
-
4425
- void dequant_q4_1_f (device const block_q4_1 * src_block, device float * dst) {
4426
- float d = float (src_block->d );
4427
- float vmin = float (src_block->m );
4428
-
4429
- #pragma unroll(16)
4430
- for (int j = 0 ; j < QK4_1/2 ; j++) {
4431
- uint8_t q = src_block->qs [j];
4432
- uint8_t q0 = q & 0x0F ;
4433
- uint8_t q1 = (q >> 4 ) & 0x0F ;
4434
- dst[j] = vmin + d * float (q0);
4435
- dst[j + QK4_1/2 ] = vmin + d * float (q1);
4436
- }
4437
- }
4438
-
4439
- void dequant_q5_0_f (device const block_q5_0 * src_block, device float * dst) {
4440
- float d = float (src_block->d );
4441
- const float shift = 16 .f ;
4442
-
4443
- // Combine the four qh bytes into a 32-bit value.
4444
- uint32_t qhVal = 0
4445
- | ((uint32_t ) src_block->qh [0 ] << 0 )
4446
- | ((uint32_t ) src_block->qh [1 ] << 8 )
4447
- | ((uint32_t ) src_block->qh [2 ] << 16 )
4448
- | ((uint32_t ) src_block->qh [3 ] << 24 );
4449
-
4450
- // First half
4451
- #pragma unroll(16)
4452
- for (int j = 0 ; j < QK5_0/2 ; j++) {
4453
- uint8_t q = src_block->qs [j];
4454
- uint8_t lowNib = q & 0x0F ;
4455
- uint8_t highBit = (qhVal >> j) & 0x1 ;
4456
- uint8_t qVal = (highBit << 4 ) | lowNib;
4457
- dst[j] = (float (qVal) - shift) * d;
4458
- }
4459
- // Second half
4460
- #pragma unroll(16)
4461
- for (int j = QK5_0/2 ; j < QK5_0; j++) {
4462
- int k = j - QK5_0/2 ;
4463
- uint8_t q = src_block->qs [k];
4464
- uint8_t hiNib = (q >> 4 ) & 0x0F ;
4465
- uint8_t highBit = (qhVal >> j) & 0x1 ;
4466
- uint8_t qVal = (highBit << 4 ) | hiNib;
4467
- dst[j] = (float (qVal) - shift) * d;
4468
- }
4469
- }
4470
-
4471
- void dequant_q5_1_f (device const block_q5_1 * src_block, device float * dst) {
4472
- float d = float (src_block->d );
4473
- float vmin = float (src_block->m );
4474
-
4475
- uint32_t qhVal = 0
4476
- | ((uint32_t ) src_block->qh [0 ] << 0 )
4477
- | ((uint32_t ) src_block->qh [1 ] << 8 )
4478
- | ((uint32_t ) src_block->qh [2 ] << 16 )
4479
- | ((uint32_t ) src_block->qh [3 ] << 24 );
4480
-
4481
- // First half
4482
- #pragma unroll(16)
4483
- for (int j = 0 ; j < QK5_1/2 ; j++) {
4484
- uint8_t q = src_block->qs [j];
4485
- uint8_t lowNib = q & 0x0F ;
4486
- uint8_t highBit = (qhVal >> j) & 0x1 ;
4487
- uint8_t qVal = (highBit << 4 ) | lowNib;
4488
- dst[j] = vmin + d * float (qVal);
4489
- }
4490
- // Second half
4491
- #pragma unroll(16)
4492
- for (int j = QK5_1/2 ; j < QK5_1; j++) {
4493
- int k = j - QK5_1/2 ;
4494
- uint8_t q = src_block->qs [k];
4495
- uint8_t hiNib = (q >> 4 ) & 0x0F ;
4496
- uint8_t highBit = (qhVal >> j) & 0x1 ;
4497
- uint8_t qVal = (highBit << 4 ) | hiNib;
4498
- dst[j] = vmin + d * float (qVal);
4499
- }
4500
- }
4501
-
4502
- void dequant_q8_0_f (device const block_q8_0 * src_block, device float * dst) {
4503
- const float d = (float )src_block->d ;
4504
-
4505
- #pragma unroll(32)
4506
- for (int j = 0 ; j < QK8_0; j++) {
4507
- dst[j] = src_block->qs [j] * d;
4508
- }
4509
- }
4510
-
4511
- typedef decltype (kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4512
-
4513
- template [[host_name(" kernel_cpy_q4_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4514
- template [[host_name(" kernel_cpy_q4_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4515
- template [[host_name(" kernel_cpy_q5_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4516
- template [[host_name(" kernel_cpy_q5_1_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4517
- template [[host_name(" kernel_cpy_q8_0_f32" )]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4518
-
4519
4418
template <typename args_t >
4520
4419
void kernel_mul_mv_q2_K_f32_impl (
4521
4420
args_t args,
0 commit comments