Skip to content

Commit 3ceb073

Browse files
committed
Merge branch 'master' into q4_1xq8_0
2 parents 7840f66 + 50a8a2a commit 3ceb073

File tree

7 files changed

+340
-21
lines changed

7 files changed

+340
-21
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ models/*
2424
/perplexity
2525
/embedding
2626
/benchmark-q4_0-matmult
27+
/vdot
2728
/Pipfile
2829

2930
arm_neon.h

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,5 @@ endif ()
305305

306306
if (LLAMA_BUILD_EXAMPLES)
307307
add_subdirectory(examples)
308+
add_subdirectory(pocs) # pocs/ (which adds the vdot target) shares the LLAMA_BUILD_EXAMPLES switch
308309
endif()

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ $(info I CC: $(CCV))
133133
$(info I CXX: $(CXXV))
134134
$(info )
135135

136-
default: main quantize quantize-stats perplexity embedding
136+
default: main quantize quantize-stats perplexity embedding vdot
137137

138138
#
139139
# Build library
@@ -169,6 +169,9 @@ perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
169169
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
170170
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
171171

172+
vdot: pocs/vdot/vdot.cpp ggml.o
173+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
174+
172175
libllama.so: llama.o ggml.o
173176
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
174177

ggml.c

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,8 +2249,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22492249
float sumf = 0.0;
22502250

22512251
#if defined(__ARM_NEON)
2252-
float sum0 = 0.0f;
2253-
float sum1 = 0.0f;
2252+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
2253+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
22542254

22552255
for (int i = 0; i < nb; i += 2) {
22562256
const block_q4_0 * restrict x0 = &x[i + 0];
@@ -2290,14 +2290,11 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22902290

22912291
#if defined(__ARM_FEATURE_DOTPROD)
22922292
// dot product into int32x4_t
2293-
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2294-
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2293+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2294+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
22952295

2296-
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2297-
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2298-
2299-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2300-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2296+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2297+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
23012298
#else
23022299
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
23032300
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2309,21 +2306,17 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23092306
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
23102307
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
23112308

2312-
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2313-
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2314-
2315-
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2316-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2317-
2318-
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2319-
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2309+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2310+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2311+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2312+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
23202313

2321-
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2322-
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2314+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2315+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
23232316
#endif
23242317
}
23252318

2326-
sumf = sum0 + sum1;
2319+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
23272320
#elif defined(__AVX2__)
23282321
// Initialize accumulator with zeros
23292322
__m256 acc = _mm256_setzero_ps();

pocs/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# dependencies
2+
3+
find_package(Threads REQUIRED)
4+
5+
# third-party
6+
7+
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
8+
9+
if (NOT EMSCRIPTEN)
10+
    # vdot is skipped for Emscripten builds; every other platform builds it.
11+
    add_subdirectory(vdot)
12+
endif()

pocs/vdot/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(TARGET vdot) # name of the executable built from vdot.cpp
2+
add_executable(${TARGET} vdot.cpp)
3+
target_link_libraries(${TARGET} PRIVATE common llama Threads::Threads) # imported target created by find_package(Threads REQUIRED) in the parent pocs/ directory
4+
target_compile_features(${TARGET} PRIVATE cxx_std_11) # build this target as C++11 or newer

0 commit comments

Comments
 (0)