Skip to content

Commit 85dee42

Browse files
committed
Arm AArch64: minor code refactoring for rebase
1 parent b1a267e commit 85dee42

File tree

4 files changed

+31
-35
lines changed

4 files changed

+31
-35
lines changed

ggml-aarch64.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ size_t quantize_q4_0_aarch64(const float * GGML_RESTRICT src, void * GGML_RESTRI
9292
}
9393
}
9494

95-
void quantize_row_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int nrows_interleaved, int blocklen_per_row) {
95+
void quantize_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int nrows_interleaved, int blocklen_per_row) {
9696
assert(QK8_0 == 32);
9797
assert(k % QK8_0 == 0);
9898
const int nb = k / QK8_0;

ggml-aarch64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ extern "C" {
1313
#endif
1414

1515
// Quantization
16-
void quantize_row_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, int nrows_interleaved, int blocklen_per_row);
16+
void quantize_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, int nrows_interleaved, int blocklen_per_row);
1717

1818
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
1919
size_t quantize_q4_0_aarch64(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

ggml-quants.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12436,6 +12436,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
1243612436
} \
1243712437
}
1243812438

12439+
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
12440+
const type * q = (const type *) (data); \
12441+
for (size_t i = 0; i < (nb); ++i) { \
12442+
for (size_t j = 0; j < (nr); ++j) { \
12443+
if (!validate_fp16(q[i].d[j], i)) { \
12444+
return false; \
12445+
} \
12446+
} \
12447+
}
12448+
1243912449
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
1244012450
if (type < 0 || type >= GGML_TYPE_COUNT) {
1244112451
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
@@ -12652,6 +12662,19 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
1265212662
{
1265312663
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
1265412664
} break;
12665+
case GGML_TYPE_Q4_0_AARCH64:
12666+
{
12667+
#if defined(__ARM_FEATURE_SVE)
12668+
if (svcntw() == 8) {
12669+
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
12670+
}
12671+
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
12672+
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
12673+
}
12674+
#elif defined(__ARM_NEON)
12675+
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
12676+
#endif
12677+
} break;
1265512678
case GGML_TYPE_I8:
1265612679
case GGML_TYPE_I16:
1265712680
case GGML_TYPE_I32:

ggml.c

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
653653
#else
654654
.nrows = 1,
655655
#endif
656-
.from_float_to_mat = quantize_row_q8_0_aarch64,
656+
.from_float_to_mat = quantize_q8_0_aarch64,
657657
},
658658
[GGML_TYPE_Q8_1] = {
659659
.type_name = "q8_1",
@@ -853,16 +853,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
853853
.blck_size = QK4_0,
854854
.type_size = sizeof(block_q4_0),
855855
.is_quantized = true,
856-
.to_float = (ggml_to_float_t) dequantize_row_q4_0,
857-
.from_float = quantize_row_q4_0,
858-
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
859-
.vec_dot = ggml_vec_dot_q4_0_q8_0,
856+
.to_float = NULL,
857+
.from_float = NULL,
858+
.from_float_reference = NULL,
859+
.vec_dot = NULL,
860860
.vec_dot_type = GGML_TYPE_Q8_0,
861-
#if defined (__ARM_FEATURE_MATMUL_INT8)
862-
.nrows = 2,
863-
#else
864861
.nrows = 1,
865-
#endif
866862
#if defined(__ARM_FEATURE_SVE)
867863
.gemv = ggml_gemv_q4_0_q8_0_aarch64_sve256,
868864
.gemm = ggml_gemm_q4_0_q8_0_aarch64_sve256,
@@ -11111,8 +11107,7 @@ UseGgmlGemm2:;
1111111107
if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (type == GGML_TYPE_Q4_0_AARCH64)) {
1111211108
gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
1111311109
}
11114-
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 16) && (type == GGML_TYPE_Q4_0_AARCH64)) {
11115-
// use nrows-sized 16, 8, and 4 GEMM kernels
11110+
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 2) && (type == GGML_TYPE_Q4_0_AARCH64)) {
1111611111
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
1111711112
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
1111811113
}
@@ -11129,28 +11124,6 @@ UseGgmlGemm2:;
1112911124
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
1113011125
}
1113111126
}
11132-
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 8) && (type == GGML_TYPE_Q4_0_AARCH64)) {
11133-
// use nrows-sized 8, and 4 GEMM kernels
11134-
for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
11135-
gemm(ne00, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), 8, ne01, ith, nth);
11136-
}
11137-
int rows_processed = (ne11 / 8) * 8;
11138-
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
11139-
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), 4, ne01, ith, nth);
11140-
}
11141-
for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
11142-
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
11143-
}
11144-
}
11145-
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (type == GGML_TYPE_Q4_0_AARCH64)) {
11146-
// use nrows-sized 4 GEMM kernel
11147-
for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
11148-
gemm(ne00, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), 4, ne01, ith, nth);
11149-
}
11150-
for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
11151-
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
11152-
}
11153-
}
1115411127
else {
1115511128
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
1115611129
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {

0 commit comments

Comments
 (0)