@@ -187,6 +187,8 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif
 
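+// IQ4_NL codebook: each 4-bit quant indexes one of these 16 non-uniform int8 values.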
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
 static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -996,6 +998,102 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
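+    // Matrix-vector product: one row of Q8_0 activations against IQ4_NL weights repacked four columns at a time.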
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon()) {
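+        // NEON path: expand each 4-bit index to its int8 codebook value with a table lookup,
+        // then accumulate with lane-indexed int8 dot products.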
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
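+        // Scalar fallback: decode each nibble through kvalues_iq4nl and accumulate in plain C.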
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3386,6 +3484,117 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
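+    // Matrix-matrix product: tiles of 4 Q8_0 activation rows times 4 interleaved IQ4_NL weight columns.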
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon()) {
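+        // NEON path: expand nibbles to int8 codebook values with a table lookup and accumulate
+        // the 4x4 output tile in four int32x4 accumulators, one per activation row.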
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
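+        // Scalar fallback: decode each nibble through kvalues_iq4nl and accumulate the 4x4 tile in plain C.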
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 // FIXME: this code is duplicated from ggml-aarch64.c
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
@@ -3518,27 +3727,101 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
     GGML_UNUSED(data_size);
 }
 
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
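+    // Interleave four block_iq4_nl rows into one block_iq4_nlx4, blck_size_interleave bytes at a time.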
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
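+    // Repack an IQ4_NL tensor, four rows at a time, into the interleaved block_iq4_nlx4 layout used by the kernels above.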
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 // Prepare for optimized kernels if applicable
 void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
     if (cur->type == repack_type) {
         memcpy(cur->data, data, data_size);
         return;
     }
 
-    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
-
-    switch (repack_type) {
-        case GGML_TYPE_Q4_0_8_8:
-            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_8:
-            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_4:
-            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (cur->type == GGML_TYPE_Q4_0) {
+        switch (repack_type) {
+            case GGML_TYPE_Q4_0_8_8:
+                repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_8:
+                repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_4:
+                repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        switch (repack_type) {
+            case GGML_TYPE_IQ4_NL_4_4:
+                repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else {
+        GGML_ABORT("Unsupported type");
     }
 }
 
@@ -3554,6 +3837,10 @@ enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * c
         if (ggml_cpu_has_neon()) {
             return GGML_TYPE_Q4_0_4_4;
         }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_IQ4_NL_4_4;
+        }
     }
 
     return cur->type;