@@ -339,8 +339,9 @@ static float table_f32_f16[1 << 16];
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s  ) B7(c,s,  c), B7(c,s,  s)

-// precomputed tables for expanding 8bits to 8 bytes (shl 4)
-static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4

#endif

// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
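For every byte value b, the B0..B8 macro chain builds a 64-bit constant whose j-th byte mirrors bit j of b: table_b2b_0[b] stores (bit << 4) per byte and table_b2b_1[b] stores ((!bit) << 4), so the two tables are bytewise complements. A minimal scalar model of one entry, for illustration only (the b2b_ref helper below is not part of the patch):

#include <stdint.h>

// Scalar model of one table entry: byte j of the result is 0x10 or 0x00,
// taken from bit j of b (inverted when `negate` is set).
static uint64_t b2b_ref(uint8_t b, int negate) {
    uint64_t r = 0;
    for (int j = 0; j < 8; ++j) {
        const uint64_t bit = (b >> j) & 1;
        r |= ((negate ? bit ^ 1 : bit) << 4) << (8*j);
    }
    return r;
}

// For all b: b2b_ref(b, 0) == table_b2b_0[b], b2b_ref(b, 1) == table_b2b_1[b],
// and table_b2b_0[b] ^ table_b2b_1[b] == 0x1010101010101010.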
@@ -2307,68 +2308,102 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
    const block_q8_0 * restrict y = vy;

#if defined(__ARM_NEON)
-    float32x4_t sumv = vdupq_n_f32(0.0f);
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);

-    uint64_t tmp[4];
+    uint32_t qh0;
+    uint32_t qh1;

-    for (int i = 0; i < nb; ++i) {
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (int i = 0; i < nb; i += 2) {
        const block_q5_0 * restrict x0 = &x[i];
+        const block_q5_0 * restrict x1 = &x[i + 1];
        const block_q8_0 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y1 = &y[i + 1];

-        const uint8x16_t m4b  = vdupq_n_u8(0x0F);
-        const int8x16_t  s16b = vdupq_n_s8(0x10);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);

-        // extract the 5th bit
-        uint32_t qh;
-        memcpy(&qh, x0->qh, sizeof(qh));
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));

-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];

-        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
-        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];

-        const uint8x16_t v0 = vld1q_u8(x0->qs);
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
-        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, m4b));
-        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // add high bit and sub 16
-        const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0l, qhl), s16b);
-        const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0h, qhh), s16b);
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);

        // load y
-        const int8x16_t v1l = vld1q_s8(y0->qs);
-        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        const float x0d = GGML_FP16_TO_FP32(x0->d);
+        const float x1d = GGML_FP16_TO_FP32(x1->d);

#if defined(__ARM_FEATURE_DOTPROD)
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
-                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));

        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
#endif
    }

-    *s = vaddvq_f32(sumv);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__wasm_simd128__)
    v128_t sumv = wasm_f32x4_splat(0.0f);

+    uint32_t qh;
    uint64_t tmp[4];

+    // TODO: check if unrolling this is better
    for (int i = 0; i < nb; ++i) {
        const block_q5_0 * restrict x0 = &x[i];
        const block_q8_0 * restrict y0 = &y[i];
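Two things change in the NEON path above: the loop now consumes two q5_0/q8_0 block pairs per iteration (so it assumes the block count nb is even), and the high-bit handling switches from table_b2b_u plus an OR and an explicit subtract of s16b to table_b2b_1, whose entries already hold (!b) << 4. The per-element identity behind that swap, written as a plain-C sketch for illustration only (q4 is a low or high nibble in [0,16), h the matching 5th bit):

// old path:  (q4 | (h << 4)) - 16
// new path:   q4 - ((!h) << 4)      // (!h) << 4 is exactly what table_b2b_1 holds
// h == 1:  (q4 | 16) - 16 == q4       and  q4 - 0  == q4
// h == 0:   q4        - 16 == q4 - 16 and  q4 - 16 == q4 - 16
static inline int q5_0_elem_ref(int q4, int h) {
    return q4 - ((!h) << 4);
}

This is why both the vorrq_s8 and the s16b constant disappear from the updated kernel; the WASM path in the next hunks applies the same identity without unrolling.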
@@ -2377,13 +2412,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
        const v128_t s16b = wasm_i8x16_splat(0x10);

        // extract the 5th bit
-        uint32_t qh;
        memcpy(&qh, x0->qh, sizeof(qh));

-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -2395,8 +2429,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit and sub 16
-        const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0l, qhl), s16b);
-        const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0h, qhh), s16b);
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
@@ -2488,69 +2522,107 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
    const block_q8_1 * restrict y = vy;

#if defined(__ARM_NEON)
-    float32x4_t sumv = vdupq_n_f32(0.0f);
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);

-    float summs = 0.0f;
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;

-    uint64_t tmp[4];
+    uint32_t qh0;
+    uint32_t qh1;

-    for (int i = 0; i < nb; ++i) {
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (int i = 0; i < nb; i += 2) {
        const block_q5_1 * restrict x0 = &x[i];
+        const block_q5_1 * restrict x1 = &x[i + 1];
        const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_1 * restrict y1 = &y[i + 1];

-        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);

-        // extract the 5th bit
-        uint32_t qh;
-        memcpy(&qh, x0->qh, sizeof(qh));
+        summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+        summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));

-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];

-        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
-        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];

-        const uint8x16_t v0 = vld1q_u8(x0->qs);
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);

        // 4-bit -> 8-bit
-        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
-        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // add
-        const int8x16_t v0lf = vorrq_s8(v0l, qhl);
-        const int8x16_t v0hf = vorrq_s8(v0h, qhh);
+        // add 5th bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);

        // load y
-        const int8x16_t v1l = vld1q_s8(y0->qs);
-        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

        const float x0d = GGML_FP16_TO_FP32(x0->d);
+        const float x1d = GGML_FP16_TO_FP32(x1->d);

#if defined(__ARM_FEATURE_DOTPROD)
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
-                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));

        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
#endif
    }

-    *s = vaddvq_f32(sumv) + summs;
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
#elif defined(__wasm_simd128__)
    v128_t sumv = wasm_f32x4_splat(0.0f);

    float summs = 0.0f;

+    uint32_t qh;
    uint64_t tmp[4];

    for (int i = 0; i < nb; ++i) {
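The q5_1 kernel keeps the OR with table_b2b_0 (plain (b) << 4) because Q5_1 stores unsigned 5-bit quants with a separate per-block minimum m rather than a fixed -16 offset; that minimum is folded into summs via the precomputed sums y->s0 and y->s1. A scalar sketch of the decomposition, for illustration only and assuming s0 + s1 hold d_y times the sum of the 32 q8 values (as the q8_1 quantization in this version stores them):

static float q5_1_q8_1_block_ref(float d_x, float m, const uint8_t q5[32],
                                 float d_y, const int8_t q8[32]) {
    int sumi = 0;   // integer dot product, computed with vdotq_s32/vmull_s8 above
    int sumq = 0;   // sum of the q8 values, assumed available as (s0 + s1)/d_y
    for (int j = 0; j < 32; ++j) {
        sumi += q5[j] * q8[j];
        sumq += q8[j];
    }
    //   sum_j (d_x*q5[j] + m) * (d_y*q8[j])
    // = d_x*d_y*sumi + m*(d_y*sumq)        where m*(d_y*sumq) == m*(s0 + s1)
    return d_x*d_y*sumi + m*(d_y*sumq);
}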
@@ -2562,13 +2634,12 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit
-        uint32_t qh;
        memcpy(&qh, x0->qh, sizeof(qh));

-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);