Skip to content

Commit 55dec7c

Browse files
committed
add neon impl
1 parent cf4fa0c commit 55dec7c

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

ggml-quants.c

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12412,7 +12412,7 @@ static bool isnan_fp16(ggml_fp16_t f) {
1241212412
return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
1241312413
}
1241412414

12415-
static inline bool validate_fp16(ggml_fp16_t f, size_t i) {
12415+
static bool validate_fp16(ggml_fp16_t f, size_t i) {
1241612416
if (isinf_fp16(f)) {
1241712417
fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
1241812418
return false;
@@ -12448,7 +12448,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1244812448
return false;
1244912449
}
1245012450

12451-
// size check
1245212451
if (nbytes % ggml_type_size(type) != 0) {
1245312452
fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
1245412453
return false;
@@ -12461,7 +12460,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1246112460
{
1246212461
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
1246312462
size_t i = 0;
12464-
#ifdef __AVX2__
12463+
#if defined(__AVX2__)
1246512464
for (; i + 15 < nb; i += 16) {
1246612465
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
1246712466
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
@@ -12476,6 +12475,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1247612475
GGML_UNREACHABLE();
1247712476
}
1247812477
}
12478+
#elif defined(__ARM_NEON)
12479+
for (; i + 7 < nb; i += 8) {
12480+
uint16x8_t v = vld1q_u16(f + i);
12481+
uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
12482+
uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
12483+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
12484+
if (mask) {
12485+
for (size_t j = 0; j < 8; ++j) {
12486+
if (!validate_fp16(f[i + j], i + j)) {
12487+
return false;
12488+
}
12489+
}
12490+
GGML_UNREACHABLE();
12491+
}
12492+
}
1247912493
#endif
1248012494
for (; i < nb; ++i) {
1248112495
if (!validate_fp16(f[i], i)) {
@@ -12487,7 +12501,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1248712501
{
1248812502
const float * f = (const float *) data;
1248912503
size_t i = 0;
12490-
#ifdef __AVX2__
12504+
#if defined(__AVX2__)
1249112505
for (; i + 7 < nb; i += 8) {
1249212506
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
1249312507
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
@@ -12502,6 +12516,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1250212516
GGML_UNREACHABLE();
1250312517
}
1250412518
}
12519+
#elif defined(__ARM_NEON)
12520+
for (; i + 3 < nb; i += 4) {
12521+
uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
12522+
uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
12523+
uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
12524+
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u32(cmp, 8)), 0);
12525+
if (mask) {
12526+
for (size_t j = 0; j < 4; ++j) {
12527+
if (!validate_float(f[i + j], i + j)) {
12528+
return false;
12529+
}
12530+
}
12531+
GGML_UNREACHABLE();
12532+
}
12533+
}
1250512534
#endif
1250612535
for (; i < nb; ++i) {
1250712536
if (!validate_float(f[i], i)) {

0 commit comments

Comments
 (0)