@@ -604,22 +604,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
604
604
#elif __ARM_NEON
605
605
for (int i = 0 ; i < nb ; i ++ ) {
606
606
float32x4_t srcv [8 ];
607
- float32x4_t asrcv [8 ];
608
- float32x4_t amaxv [8 ];
607
+ float32x4_t maxv [8 ];
608
+ float32x4_t minv [8 ];
609
609
610
610
for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
611
- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
612
611
613
- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
614
- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
615
- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
612
+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vmaxq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
613
+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = vmaxq_f32 (maxv [4 * l ], maxv [4 * l + 2 ]);
614
+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = vmaxq_f32 (maxv [8 * l ], maxv [8 * l + 4 ]);
616
615
617
- // absolute max
618
- const float amax = MAX (
619
- MAX (vgetq_lane_f32 (amaxv [0 ], 0 ), vgetq_lane_f32 (amaxv [0 ], 1 )),
620
- MAX (vgetq_lane_f32 (amaxv [0 ], 2 ), vgetq_lane_f32 (amaxv [0 ], 3 )));
616
+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
617
+ for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
618
+ for (int l = 0 ; l < 1 ; l ++ ) minv [8 * l ] = vminq_f32 (minv [8 * l ], minv [8 * l + 4 ]);
621
619
622
- const float d = amax / ((1 << 3 ) - 1 );
620
+ const float max = vmaxvq_f32 (maxv [0 ]);
621
+ const float min = vminvq_f32 (minv [0 ]);
622
+
623
+ const float magnitude = max >= fabsf (min ) ? max : min ;
624
+ const float d = magnitude / -8 ;
623
625
const float id = d ? 1.0f /d : 0.0f ;
624
626
625
627
y [i ].d = d ;
@@ -628,9 +630,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
628
630
const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
629
631
const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (8.5f ));
630
632
const int32x4_t vi = vcvtq_s32_f32 (vf );
633
+ const int32x4_t vc = vminq_s32 (vi , vdupq_n_s32 (15 ));
631
634
632
- y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
633
- y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
635
+ y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vc , 0 ) | (vgetq_lane_s32 (vc , 1 ) << 4 );
636
+ y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vc , 2 ) | (vgetq_lane_s32 (vc , 3 ) << 4 );
634
637
}
635
638
}
636
639
#elif defined(__AVX2__ )
0 commit comments