@@ -565,28 +565,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
565
565
566
566
#if defined(__POWER9_VECTOR__ )
567
567
const vector float v85 = vec_splats (8.5f );
568
+ const vector signed int v15 = vec_splats (15 );
568
569
for (int i = 0 ; i < nb ; i ++ ) {
569
- float amax = 0.0f ; // absolute max
570
+ float max = 0.0f ;
571
+ float min = 0.0f ;
570
572
571
573
vector float srcv [8 ];
572
- vector float asrcv [8 ];
573
- vector float amaxv [8 ];
574
+ vector float maxv [8 ];
575
+ vector float minv [8 ];
574
576
575
577
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = * (vector float * )(x + i * 32 + 4 * l );
576
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vec_abs (srcv [l ]);
578
+ // for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
577
579
578
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
579
- //for (int l = 0; l < 2; l++) amaxv [4*l] = vec_max(amaxv [4*l], amaxv [4*l+2]);
580
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [2 ]);
581
- amaxv [4 ] = vec_max (amaxv [4 ], amaxv [6 ]);
582
- //for (int l = 0; l < 1; l++) amaxv [8*l] = vec_max(amaxv [8*l], amaxv [8*l+4]);
583
- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [4 ]);
580
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
581
+ //for (int l = 0; l < 2; l++) maxv [4*l] = vec_max(maxv [4*l], maxv [4*l+2]);
582
+ maxv [0 ] = vec_max (maxv [0 ], maxv [2 ]);
583
+ maxv [4 ] = vec_max (maxv [4 ], maxv [6 ]);
584
+ //for (int l = 0; l < 1; l++) maxv [8*l] = vec_max(maxv [8*l], maxv [8*l+4]);
585
+ maxv [0 ] = vec_max (maxv [0 ], maxv [4 ]);
584
586
585
- amax = MAX (
586
- MAX (vec_extract (amaxv [0 ], 0 ), vec_extract (amaxv [0 ], 1 )),
587
- MAX (vec_extract (amaxv [0 ], 2 ), vec_extract (amaxv [0 ], 3 )));
587
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vec_min (asrcv [2 * l ], asrcv [2 * l + 1 ]);
588
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
589
+ minv [0 ] = vec_min (minv [0 ], minv [2 ]);
590
+ minv [4 ] = vec_min (minv [4 ], minv [6 ]);
591
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
592
+ minv [0 ] = vec_min (minv [0 ], minv [4 ]);
588
593
589
- const float d = amax / ((1 << 3 ) - 1 );
594
+
595
+ max = MAX (
596
+ MAX (vec_extract (maxv [0 ], 0 ), vec_extract (maxv [0 ], 1 )),
597
+ MAX (vec_extract (maxv [0 ], 2 ), vec_extract (maxv [0 ], 3 )));
598
+ min = MIN (
599
+ MIN (vec_extract (minv [0 ], 0 ), vec_extract (minv [0 ], 1 )),
600
+ MIN (vec_extract (minv [0 ], 2 ), vec_extract (minv [0 ], 3 )));
601
+
602
+ const float magnitude = max >= fabsf (min ) ? max : min ;
603
+ const float d = magnitude / -8 ;
590
604
const float id = d ? 1.0 /d : 0.0 ;
591
605
592
606
y [i ].d = d ;
@@ -596,9 +610,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
596
610
for (int l = 0 ; l < 8 ; l ++ ) {
597
611
const vector float vf = vec_madd (srcv [l ], vid , v85 );
598
612
const vector signed int vi = vec_signed (vf );
613
+ const vector signed int vc = vec_min (vi , v15 );
599
614
600
- pb [2 * l + 0 ] = vec_extract (vi , 0 ) | (vec_extract (vi , 1 ) << 4 );
601
- pb [2 * l + 1 ] = vec_extract (vi , 2 ) | (vec_extract (vi , 3 ) << 4 );
615
+ pb [2 * l + 0 ] = vec_extract (vc , 0 ) | (vec_extract (vc , 1 ) << 4 );
616
+ pb [2 * l + 1 ] = vec_extract (vc , 2 ) | (vec_extract (vc , 3 ) << 4 );
602
617
}
603
618
}
604
619
#elif __ARM_NEON
0 commit comments