Skip to content

Commit 6e64a6a

Browse files
committed
Update quantize_row_q4_0 for WASM
Untested
1 parent 5d447e9 commit 6e64a6a

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

ggml.c

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -797,24 +797,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
797797
}
798798
#elif defined(__wasm_simd128__)
799799
for (int i = 0; i < nb; i++) {
800-
float amax = 0.0f; // absolute max
800+
float max = 0.0f;
801+
float min = 0.0f;
801802

802803
v128_t srcv [8];
803-
v128_t asrcv[8];
804-
v128_t amaxv[8];
804+
v128_t maxv[8];
805+
v128_t minv[8];
805806

806807
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
807-
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
808808

809-
for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
810-
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
811-
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
809+
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
810+
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
811+
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
812812

813-
amax = MAX(
814-
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
815-
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
813+
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
814+
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
815+
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
816816

817-
const float d = amax / ((1 << 3) - 1);
817+
max = MAX(
818+
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
819+
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
820+
min = MIN(
821+
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
822+
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
823+
824+
const float magnitude = max >= fabsf(min) ? max : min;
825+
const float d = magnitude / -8;
818826
const float id = d ? 1.0/d : 0.0;
819827

820828
y[i].d = d;
@@ -823,9 +831,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
823831
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
824832
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
825833
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
834+
const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
826835

827-
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
828-
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
836+
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
837+
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
829838
}
830839
}
831840
#else

0 commit comments

Comments
 (0)