Skip to content

Commit c806db3

Browse files
committed
improve fp16 validation performance
1 parent 6aea16e commit c806db3

File tree

1 file changed

+61
-9
lines changed

1 file changed

+61
-9
lines changed

ggml-quants.c

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12404,22 +12404,40 @@ static bool validate_float(float f, size_t i) {
1240412404
return true;
1240512405
}
1240612406

12407-
static bool validate_f16(ggml_fp16_t f, size_t i) {
12408-
return validate_float(GGML_FP16_TO_FP32(f), i);
12407+
static bool isinf_fp16(ggml_fp16_t f) {
12408+
return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
12409+
}
12410+
12411+
static bool isnan_fp16(ggml_fp16_t f) {
12412+
return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
12413+
}
12414+
12415+
static inline bool validate_fp16(ggml_fp16_t f, size_t i) {
12416+
if (isinf_fp16(f)) {
12417+
fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
12418+
return false;
12419+
}
12420+
12421+
if (isnan_fp16(f)) {
12422+
fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
12423+
return false;
12424+
}
12425+
12426+
return true;
1240912427
}
1241012428

1241112429
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
1241212430
const type * q = (const type *) (data); \
1241312431
for (size_t i = 0; i < (nb); ++i) { \
12414-
if (!validate_f16(q[i].d, i)) { \
12432+
if (!validate_fp16(q[i].d, i)) { \
1241512433
return false; \
1241612434
} \
1241712435
}
1241812436

1241912437
#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
1242012438
const type * q = (const type *) (data); \
1242112439
for (size_t i = 0; i < (nb); ++i) { \
12422-
if (!validate_f16(q[i].d, i) || !validate_f16(q[i].m, i)) { \
12440+
if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
1242312441
return false; \
1242412442
} \
1242512443
}
@@ -12436,22 +12454,56 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1243612454
return false;
1243712455
}
1243812456

12439-
size_t nb = nbytes/ggml_type_size(type);
12457+
const size_t nb = nbytes/ggml_type_size(type);
1244012458

1244112459
switch (type) {
1244212460
case GGML_TYPE_F16:
1244312461
{
1244412462
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
12445-
for (size_t i = 0; i < nb; ++i) {
12446-
if (!validate_f16(f[i], i)) {
12463+
size_t i = 0;
12464+
#ifdef __AVX2__
12465+
for (; i + 15 < nb; i += 16) {
12466+
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
12467+
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
12468+
__m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
12469+
int mask = _mm256_movemask_epi8(cmp);
12470+
if (mask) {
12471+
for (size_t j = 0; j < 16; ++j) {
12472+
if (!validate_fp16(f[i + j], i + j)) {
12473+
return false;
12474+
}
12475+
}
12476+
GGML_UNREACHABLE();
12477+
}
12478+
}
12479+
#endif
12480+
for (; i < nb; ++i) {
12481+
if (!validate_fp16(f[i], i)) {
1244712482
return false;
1244812483
}
1244912484
}
1245012485
} break;
1245112486
case GGML_TYPE_F32:
1245212487
{
1245312488
const float * f = (const float *) data;
12454-
for (size_t i = 0; i < nb; ++i) {
12489+
size_t i = 0;
12490+
#ifdef __AVX2__
12491+
for (; i + 7 < nb; i += 8) {
12492+
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
12493+
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
12494+
__m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
12495+
int mask = _mm256_movemask_epi8(cmp);
12496+
if (mask) {
12497+
for (size_t j = 0; j < 8; ++j) {
12498+
if (!validate_float(f[i + j], i + j)) {
12499+
return false;
12500+
}
12501+
}
12502+
GGML_UNREACHABLE();
12503+
}
12504+
}
12505+
#endif
12506+
for (; i < nb; ++i) {
1245512507
if (!validate_float(f[i], i)) {
1245612508
return false;
1245712509
}
@@ -12539,7 +12591,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1253912591
iq1m_scale_t scale;
1254012592
const uint16_t * sc = (const uint16_t *)q[i].scales;
1254112593
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
12542-
if (!validate_f16(scale.f16, i)) {
12594+
if (!validate_fp16(scale.f16, i)) {
1254312595
return false;
1254412596
}
1254512597
#endif

0 commit comments

Comments
 (0)