@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif

+#if defined(__loongarch_sx)
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 = lsx_hadd_s(a, b);
+    __m128 res_1 = lsx_hadd_s(c, d);
+    __m128 res = lsx_hadd_s(res_0, res_1);
+    res = lsx_hadd_s(res, res);
+    res = lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
#if defined(__loongarch_asx)

#ifdef __clang__
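The reduction in hsum_float_4x4 above folds every lane into lane 0 by repeated pairwise adds. As a plain-C reference (illustrative only, not part of the commit; the name hsum_float_4x4_ref and the array-of-float view of __m128 are assumptions), it computes:

    // Scalar equivalent of hsum_float_4x4: sum of all 16 float lanes of a, b, c, d.
    static inline float hsum_float_4x4_ref(const float a[4], const float b[4],
                                           const float c[4], const float d[4]) {
        float sum = 0.0f;
        for (int i = 0; i < 4; ++i) {
            sum += a[i] + b[i] + c[i] + d[i];
        }
        return sum;
    }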
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
    return (__m256i)__ret;
}

-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
    v4i64 __ret = {d, c, b, a};
    return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
    return lasx_set_q(x, y);
}

-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
    __m256i mask_f, zero, tmp0, tmp2, mask;
    int f = 0x8f;
@@ -482,25 +549,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
    return ret;
}

-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
    __m256i tmp1, tmp2;
    tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,42 +577,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
    return __lasx_xvpickev_b(tmp1, tmp);
}

-static __m128i lsx_packs_w(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_w(a, 15);
-    tmp1 = __lsx_vsat_w(b, 15);
-    return __lsx_vpickev_h(tmp1, tmp);
-}
-
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-
-static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_h_b(a, b);
-    tmp2 = __lsx_vmulwod_h_b(a, b);
-    return __lsx_vsadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_madd_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_w_h(a, b);
-    tmp2 = __lsx_vmulwod_w_h(a, b);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
    // Get absolute values of x vectors
@@ -2232,21 +2244,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
    }

    sumf = hsum_float_8(acc);
+
#elif defined(__loongarch_sx)
    // set constants
    const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
    const __m128i off = __lsx_vreplgr2vr_b(8);

    // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);

    for (; ib + 1 < nb; ib += 2) {

        // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
@@ -2264,7 +2277,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
        //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

        // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );

        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
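For context, the __loongarch_sx branch above vectorizes the q4_0 × q8_0 block dot product. Below is a scalar sketch of the per-block arithmetic (illustrative only, not the committed kernel); it reuses the x, y, nb and block fields visible in the hunk and assumes the usual ggml q4_0/q8_0 layout: 32 values per block, 4-bit quants packed as low/high nibbles with an implicit -8 offset, and one fp16 scale per block.

    // Scalar sketch: what the LSX loop accumulates, block by block.
    float sumf = 0.0f;
    for (int ib = 0; ib < nb; ++ib) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8; // low nibble  -> element j
            const int v1 = (x[ib].qs[j] >>   4) - 8; // high nibble -> element j + 16
            sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + 16];
        }
        sumf += GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) * (float) sumi;
    }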