Skip to content

Commit 366805e

Browse files
committed
[LIBC] Add an optimized memcmp implementation for AArch64
Differential Revision: https://reviews.llvm.org/D105441
1 parent d0fe294 commit 366805e

File tree

6 files changed

+179
-45
lines changed

6 files changed

+179
-45
lines changed

libc/src/string/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ endif()
295295

296296
function(add_memcmp memcmp_name)
297297
add_implementation(memcmp ${memcmp_name}
298-
SRCS ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp
298+
SRCS ${LIBC_MEMCMP_SRC}
299299
HDRS ${LIBC_SOURCE_DIR}/src/string/memcmp.h
300300
DEPENDS
301301
.memory_utils.memory_utils
@@ -307,13 +307,19 @@ function(add_memcmp memcmp_name)
307307
endfunction()
308308

309309
if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
310+
set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
310311
add_memcmp(memcmp_x86_64_opt_sse2 COMPILE_OPTIONS -march=k8 REQUIRE SSE2)
311312
add_memcmp(memcmp_x86_64_opt_sse4 COMPILE_OPTIONS -march=nehalem REQUIRE SSE4_2)
312313
add_memcmp(memcmp_x86_64_opt_avx2 COMPILE_OPTIONS -march=haswell REQUIRE AVX2)
313314
add_memcmp(memcmp_x86_64_opt_avx512 COMPILE_OPTIONS -march=skylake-avx512 REQUIRE AVX512F)
314315
add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
315316
add_memcmp(memcmp)
317+
elseif(${LIBC_TARGET_ARCHITECTURE_IS_AARCH64})
318+
set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/aarch64/memcmp.cpp)
319+
add_memcmp(memcmp)
320+
add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
316321
else()
322+
set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
317323
add_memcmp(memcmp_opt_host COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
318324
add_memcmp(memcmp)
319325
endif()

libc/src/string/aarch64/memcmp.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===-- Implementation of memcmp ------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/string/memcmp.h"
10+
#include "src/__support/common.h"
11+
#include "src/string/memory_utils/elements.h"
12+
#include <stddef.h> // size_t
13+
14+
namespace __llvm_libc {
15+
namespace aarch64 {
16+
17+
static int memcmp_impl(const char *lhs, const char *rhs, size_t count) {
18+
if (count == 0)
19+
return 0;
20+
if (count == 1)
21+
return ThreeWayCompare<_1>(lhs, rhs);
22+
else if (count == 2)
23+
return ThreeWayCompare<_2>(lhs, rhs);
24+
else if (count == 3)
25+
return ThreeWayCompare<_3>(lhs, rhs);
26+
else if (count < 8)
27+
return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
28+
else if (count < 16)
29+
return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
30+
else if (count < 128) {
31+
if (Equals<_16>(lhs, rhs)) {
32+
if (count < 32)
33+
return ThreeWayCompare<Tail<_16>>(lhs, rhs, count);
34+
else {
35+
if (Equals<_16>(lhs + 16, rhs + 16)) {
36+
if (count < 64)
37+
return ThreeWayCompare<Tail<_32>>(lhs, rhs, count);
38+
if (count < 128)
39+
return ThreeWayCompare<Loop<_16>>(lhs + 32, rhs + 32, count - 32);
40+
} else
41+
return ThreeWayCompare<_16>(lhs + count - 32, rhs + count - 32);
42+
}
43+
}
44+
return ThreeWayCompare<_16>(lhs, rhs);
45+
} else
46+
return ThreeWayCompare<Align<_16, Arg::Lhs>::Then<Loop<_32>>>(lhs, rhs,
47+
count);
48+
}
49+
} // namespace aarch64
50+
51+
LLVM_LIBC_FUNCTION(int, memcmp,
52+
(const void *lhs, const void *rhs, size_t count)) {
53+
54+
const char *_lhs = reinterpret_cast<const char *>(lhs);
55+
const char *_rhs = reinterpret_cast<const char *>(rhs);
56+
return aarch64::memcmp_impl(_lhs, _rhs, count);
57+
}
58+
59+
} // namespace __llvm_libc

libc/src/string/memory_utils/elements.h

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ template <typename T> struct HeadTail {
211211
}
212212

213213
static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
214-
if (const int result = T::ThreeWayCompare(lhs, rhs))
215-
return result;
214+
if (!T::Equals(lhs, rhs))
215+
return T::ThreeWayCompare(lhs, rhs);
216216
return Tail<T>::ThreeWayCompare(lhs, rhs, size);
217217
}
218218

@@ -251,8 +251,8 @@ template <typename T> struct Loop {
251251

252252
static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
253253
for (size_t offset = 0; offset < size - T::kSize; offset += T::kSize)
254-
if (const int result = T::ThreeWayCompare(lhs + offset, rhs + offset))
255-
return result;
254+
if (!T::Equals(lhs + offset, rhs + offset))
255+
return T::ThreeWayCompare(lhs + offset, rhs + offset);
256256
return Tail<T>::ThreeWayCompare(lhs, rhs, size);
257257
}
258258

@@ -327,8 +327,8 @@ template <typename AlignmentT, Arg AlignOn> struct Align {
327327
}
328328

329329
static int ThreeWayCompare(const char *lhs, const char *rhs, size_t size) {
330-
if (const int result = AlignmentT::ThreeWayCompare(lhs, rhs))
331-
return result;
330+
if (!AlignmentT::Equals(lhs, rhs))
331+
return AlignmentT::ThreeWayCompare(lhs, rhs);
332332
internal::AlignHelper<AlignOn, Alignment>::Bump(lhs, rhs, size);
333333
return NextT::ThreeWayCompare(lhs, rhs, size);
334334
}
@@ -370,12 +370,18 @@ template <size_t Size> struct Builtin {
370370
#endif
371371
}
372372

373+
#if __has_builtin(__builtin_memcmp_inline)
374+
#define LLVM_LIBC_MEMCMP __builtin_memcmp_inline
375+
#else
376+
#define LLVM_LIBC_MEMCMP __builtin_memcmp
377+
#endif
378+
373379
static bool Equals(const char *lhs, const char *rhs) {
374-
return __builtin_memcmp(lhs, rhs, kSize) == 0;
380+
return LLVM_LIBC_MEMCMP(lhs, rhs, kSize) == 0;
375381
}
376382

377383
static int ThreeWayCompare(const char *lhs, const char *rhs) {
378-
return __builtin_memcmp(lhs, rhs, kSize);
384+
return LLVM_LIBC_MEMCMP(lhs, rhs, kSize);
379385
}
380386

381387
static void SplatSet(char *dst, const unsigned char value) {
@@ -428,6 +434,8 @@ template <typename T> struct Scalar {
428434
Store(dst, GetSplattedValue(value));
429435
}
430436

437+
static int ScalarThreeWayCompare(T a, T b);
438+
431439
private:
432440
static T Load(const char *ptr) {
433441
T value;
@@ -440,7 +448,6 @@ template <typename T> struct Scalar {
440448
static T GetSplattedValue(const unsigned char value) {
441449
return T(~0) / T(0xFF) * T(value);
442450
}
443-
static int ScalarThreeWayCompare(T a, T b);
444451
};
445452

446453
template <>
@@ -457,23 +464,15 @@ inline int Scalar<uint16_t>::ScalarThreeWayCompare(uint16_t a, uint16_t b) {
457464
}
458465
template <>
459466
inline int Scalar<uint32_t>::ScalarThreeWayCompare(uint32_t a, uint32_t b) {
460-
const int64_t la = Endian::ToBigEndian(a);
461-
const int64_t lb = Endian::ToBigEndian(b);
462-
if (la < lb)
463-
return -1;
464-
if (la > lb)
465-
return 1;
466-
return 0;
467+
const uint32_t la = Endian::ToBigEndian(a);
468+
const uint32_t lb = Endian::ToBigEndian(b);
469+
return la > lb ? 1 : la < lb ? -1 : 0;
467470
}
468471
template <>
469472
inline int Scalar<uint64_t>::ScalarThreeWayCompare(uint64_t a, uint64_t b) {
470-
const __int128_t la = Endian::ToBigEndian(a);
471-
const __int128_t lb = Endian::ToBigEndian(b);
472-
if (la < lb)
473-
return -1;
474-
if (la > lb)
475-
return 1;
476-
return 0;
473+
const uint64_t la = Endian::ToBigEndian(a);
474+
const uint64_t lb = Endian::ToBigEndian(b);
475+
return la > lb ? 1 : la < lb ? -1 : 0;
477476
}
478477

479478
using UINT8 = Scalar<uint8_t>; // 1 Byte
@@ -494,6 +493,7 @@ using _128 = Repeated<_8, 16>;
494493
} // namespace scalar
495494
} // namespace __llvm_libc
496495

496+
#include <src/string/memory_utils/elements_aarch64.h>
497497
#include <src/string/memory_utils/elements_x86.h>
498498

499499
#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_H
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
//===-- Elementary operations for aarch64 --------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
10+
#define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H
11+
12+
#include <src/string/memory_utils/elements.h>
13+
#include <stddef.h> // size_t
14+
#include <stdint.h> // uint8_t, uint16_t, uint32_t, uint64_t
15+
16+
#ifdef __ARM_NEON
17+
#include <arm_neon.h>
18+
#endif
19+
20+
namespace __llvm_libc {
21+
namespace aarch64 {
22+
23+
using _1 = __llvm_libc::scalar::_1;
24+
using _2 = __llvm_libc::scalar::_2;
25+
using _3 = __llvm_libc::scalar::_3;
26+
using _4 = __llvm_libc::scalar::_4;
27+
using _8 = __llvm_libc::scalar::_8;
28+
using _16 = __llvm_libc::scalar::_16;
29+
30+
#ifdef __ARM_NEON
31+
struct N32 {
32+
static constexpr size_t kSize = 32;
33+
static bool Equals(const char *lhs, const char *rhs) {
34+
uint8x16_t l_0 = vld1q_u8((const uint8_t *)lhs);
35+
uint8x16_t r_0 = vld1q_u8((const uint8_t *)rhs);
36+
uint8x16_t l_1 = vld1q_u8((const uint8_t *)(lhs + 16));
37+
uint8x16_t r_1 = vld1q_u8((const uint8_t *)(rhs + 16));
38+
uint8x16_t temp = vpmaxq_u8(veorq_u8(l_0, r_0), veorq_u8(l_1, r_1));
39+
uint64_t res =
40+
vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(temp, temp)), 0);
41+
return res == 0;
42+
}
43+
static int ThreeWayCompare(const char *lhs, const char *rhs) {
44+
uint8x16_t l_0 = vld1q_u8((const uint8_t *)lhs);
45+
uint8x16_t r_0 = vld1q_u8((const uint8_t *)rhs);
46+
uint8x16_t l_1 = vld1q_u8((const uint8_t *)(lhs + 16));
47+
uint8x16_t r_1 = vld1q_u8((const uint8_t *)(rhs + 16));
48+
uint8x16_t temp = vpmaxq_u8(veorq_u8(l_0, r_0), veorq_u8(l_1, r_1));
49+
uint64_t res =
50+
vgetq_lane_u64(vreinterpretq_u64_u8(vpmaxq_u8(temp, temp)), 0);
51+
if (res == 0)
52+
return 0;
53+
size_t index = (__builtin_ctzl(res) >> 3) << 2;
54+
uint32_t l = *((const uint32_t *)(lhs + index));
55+
uint32_t r = *((const uint32_t *)(rhs + index));
56+
return __llvm_libc::scalar::_4::ScalarThreeWayCompare(l, r);
57+
}
58+
};
59+
60+
using _32 = N32;
61+
#else
62+
using _32 = __llvm_libc::scalar::_32;
63+
#endif // __ARM_NEON
64+
65+
} // namespace aarch64
66+
} // namespace __llvm_libc
67+
68+
#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_ELEMENTS_AARCH64_H

libc/test/src/string/CMakeLists.txt

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,6 @@ add_libc_unittest(
5252
libc.src.string.memchr
5353
)
5454

55-
add_libc_unittest(
56-
memcmp_test
57-
SUITE
58-
libc_string_unittests
59-
SRCS
60-
memcmp_test.cpp
61-
DEPENDS
62-
libc.src.string.memcmp
63-
)
64-
65-
add_libc_unittest(
66-
memmove_test
67-
SUITE
68-
libc_string_unittests
69-
SRCS
70-
memmove_test.cpp
71-
DEPENDS
72-
libc.src.string.memcmp
73-
libc.src.string.memmove
74-
)
75-
7655
add_libc_unittest(
7756
strchr_test
7857
SUITE
@@ -209,3 +188,5 @@ endfunction()
209188
add_libc_multi_impl_test(memcpy SRCS memcpy_test.cpp)
210189
add_libc_multi_impl_test(memset SRCS memset_test.cpp)
211190
add_libc_multi_impl_test(bzero SRCS bzero_test.cpp)
191+
add_libc_multi_impl_test(memcmp SRCS memcmp_test.cpp)
192+
add_libc_multi_impl_test(memmove SRCS memmove_test.cpp)

libc/test/src/string/memcmp_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "src/string/memcmp.h"
1010
#include "utils/UnitTest/Test.h"
11+
#include <cstring>
1112

1213
TEST(LlvmLibcMemcmpTest, CmpZeroByte) {
1314
const char *lhs = "ab";
@@ -32,3 +33,22 @@ TEST(LlvmLibcMemcmpTest, LhsAfterRhsLexically) {
3233
const char *rhs = "ab";
3334
EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, 2), 1);
3435
}
36+
37+
TEST(LlvmLibcMemcmpTest, Sweep) {
38+
static constexpr size_t kMaxSize = 1024;
39+
char lhs[kMaxSize];
40+
char rhs[kMaxSize];
41+
42+
memset(lhs, 'a', sizeof(lhs));
43+
memset(rhs, 'a', sizeof(rhs));
44+
for (int i = 0; i < kMaxSize; ++i)
45+
EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, i), 0);
46+
47+
memset(lhs, 'a', sizeof(lhs));
48+
memset(rhs, 'a', sizeof(rhs));
49+
for (int i = 0; i < kMaxSize; ++i) {
50+
rhs[i] = 'b';
51+
EXPECT_EQ(__llvm_libc::memcmp(lhs, rhs, kMaxSize), -1);
52+
rhs[i] = 'a';
53+
}
54+
}

0 commit comments

Comments
 (0)