Skip to content

Commit c2936b6

Browse files
committed
[CodeGen] Improve ExpandMemCmp for more efficient non-register aligned sizes handling
* Enhanced the logic of the ExpandMemCmp pass to merge contiguous subsequences in LoadSequence, based on the sizes allowed in `AllowedTailExpansions`.
* This enhancement seeks to minimize the number of basic blocks and produce optimized code when using memcmp with non-register-aligned sizes.
* Enabled this feature for AArch64 with memcmp sizes modulo 8 equal to 3, 5, and 6.
1 parent 7548c46 commit c2936b6

File tree

5 files changed

+140
-150
lines changed

5 files changed

+140
-150
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,17 @@ class TargetTransformInfo {
907907
// be done with two 4-byte compares instead of 4+2+1-byte compares. This
908908
// requires all loads in LoadSizes to be doable in an unaligned way.
909909
bool AllowOverlappingLoads = false;
910+
911+
// Sometimes, the amount of data that needs to be compared is smaller than
912+
// the standard register size, but it cannot be loaded with just one load
913+
// instruction. For example, if the size of the memory comparison is 6
914+
// bytes, we can handle it more efficiently by loading all 6 bytes in a
915+
// single block and generating an 8-byte number, instead of generating two
916+
// separate blocks with conditional jumps for 4 and 2 byte loads. This
917+
// approach simplifies the process and produces the comparison result as
918+
// normal. This array lists the allowed sizes of memcmp tails that can be
919+
// merged into one block.
920+
SmallVector<unsigned, 4> AllowedTailExpansions;
910921
};
911922
MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
912923
bool IsZeroCmp) const;

llvm/lib/CodeGen/ExpandMemCmp.cpp

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class MemCmpExpansion {
117117
Value *Lhs = nullptr;
118118
Value *Rhs = nullptr;
119119
};
120-
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
121-
unsigned OffsetBytes);
120+
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *BSwapSizeType,
121+
Type *CmpSizeType, unsigned OffsetBytes);
122122

123123
static LoadEntryVector
124124
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -255,6 +255,31 @@ MemCmpExpansion::MemCmpExpansion(
255255
}
256256
}
257257
assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
258+
// This part of code attempts to optimize the LoadSequence by merging allowed
259+
// subsequences into single loads of allowed sizes from
260+
// `AllowedTailExpansions`. If it is for zero comparison or if no allowed tail
261+
// expansions are specified, we exit early.
262+
if (IsUsedForZeroCmp || !Options.AllowedTailExpansions.size())
263+
return;
264+
265+
while (LoadSequence.size() >= 2) {
266+
auto Last = LoadSequence[LoadSequence.size() - 1];
267+
auto PreLast = LoadSequence[LoadSequence.size() - 2];
268+
269+
// Exit the loop if the two sequences are not contiguous
270+
if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
271+
break;
272+
273+
auto LoadSize = Last.LoadSize + PreLast.LoadSize;
274+
if (find(Options.AllowedTailExpansions, LoadSize) ==
275+
Options.AllowedTailExpansions.end())
276+
break;
277+
278+
// Remove the last two sequences and replace with the combined sequence
279+
LoadSequence.pop_back();
280+
LoadSequence.pop_back();
281+
LoadSequence.emplace_back(PreLast.Offset, LoadSize);
282+
}
258283
}
259284

260285
unsigned MemCmpExpansion::getNumBlocks() {
@@ -279,6 +304,7 @@ void MemCmpExpansion::createResultBlock() {
279304

280305
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
281306
bool NeedsBSwap,
307+
Type *BSwapSizeType,
282308
Type *CmpSizeType,
283309
unsigned OffsetBytes) {
284310
// Get the memory source at offset `OffsetBytes`.
@@ -307,16 +333,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
307333
if (!Rhs)
308334
Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
309335

336+
// Zero extend if the byte swap intrinsic has a different type.
337+
if (NeedsBSwap && LoadSizeType != BSwapSizeType) {
338+
Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
339+
Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
340+
}
341+
310342
// Swap bytes if required.
311343
if (NeedsBSwap) {
312-
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
313-
Intrinsic::bswap, LoadSizeType);
344+
Function *Bswap = Intrinsic::getDeclaration(
345+
CI->getModule(), Intrinsic::bswap, BSwapSizeType);
314346
Lhs = Builder.CreateCall(Bswap, Lhs);
315347
Rhs = Builder.CreateCall(Bswap, Rhs);
316348
}
317349

318350
// Zero extend if required.
319-
if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
351+
if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
320352
Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
321353
Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
322354
}
@@ -333,7 +365,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
333365
Builder.SetInsertPoint(BB);
334366
const LoadPair Loads =
335367
getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
336-
Type::getInt32Ty(CI->getContext()), OffsetBytes);
368+
nullptr, Type::getInt32Ty(CI->getContext()), OffsetBytes);
337369
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
338370

339371
PhiRes->addIncoming(Diff, BB);
@@ -385,11 +417,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
385417
IntegerType *const MaxLoadType =
386418
NumLoads == 1 ? nullptr
387419
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);
420+
388421
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
389422
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
390423
const LoadPair Loads = getLoadPair(
391424
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
392-
/*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
425+
/*NeedsBSwap=*/false, nullptr, MaxLoadType, CurLoadEntry.Offset);
393426

394427
if (NumLoads != 1) {
395428
// If we have multiple loads per block, we need to generate a composite
@@ -475,14 +508,18 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
475508

476509
Type *LoadSizeType =
477510
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
478-
Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
511+
Type *BSwapSizeType = IntegerType::get(
512+
CI->getContext(), PowerOf2Ceil(CurLoadEntry.LoadSize * 8));
513+
Type *MaxLoadType = IntegerType::get(
514+
CI->getContext(),
515+
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
479516
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
480517

481518
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
482519

483520
const LoadPair Loads =
484-
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
485-
CurLoadEntry.Offset);
521+
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(),
522+
BSwapSizeType, MaxLoadType, CurLoadEntry.Offset);
486523

487524
// Add the loaded values to the phi nodes for calculating memcmp result only
488525
// if result is not used in a zero equality.
@@ -588,19 +625,26 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
588625
/// the compare, branch, and phi IR that is required in the general case.
589626
Value *MemCmpExpansion::getMemCmpOneBlock() {
590627
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
628+
Type *BSwapSizeType =
629+
IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8));
630+
Type *MaxLoadType =
631+
IntegerType::get(CI->getContext(),
632+
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);
633+
591634
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
592635

593636
// The i8 and i16 cases don't need compares. We zext the loaded values and
594637
// subtract them to get the suitable negative, zero, or positive i32 result.
595638
if (Size < 4) {
596-
const LoadPair Loads =
597-
getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
598-
/*Offset*/ 0);
639+
const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType,
640+
Builder.getInt32Ty(),
641+
/*Offset*/ 0);
599642
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
600643
}
601644

602-
const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
603-
/*Offset*/ 0);
645+
const LoadPair Loads =
646+
getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType, MaxLoadType,
647+
/*Offset*/ 0);
604648
// The result of memcmp is negative, zero, or positive, so produce that by
605649
// subtracting 2 extended compare bits: sub (ugt, ult).
606650
// If a target prefers to use selects to get -1/0/1, they should be able

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,6 +2961,7 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
29612961
// they may wake up the FP unit, which raises the power consumption. Perhaps
29622962
// they could be used with no holds barred (-O3).
29632963
Options.LoadSizes = {8, 4, 2, 1};
2964+
Options.AllowedTailExpansions = {3, 5, 6};
29642965
return Options;
29652966
}
29662967

llvm/test/CodeGen/AArch64/memcmp.ll

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -152,22 +152,15 @@ define i1 @length2_eq_nobuiltin_attr(ptr %X, ptr %Y) nounwind {
152152
define i32 @length3(ptr %X, ptr %Y) nounwind {
153153
; CHECK-LABEL: length3:
154154
; CHECK: // %bb.0:
155-
; CHECK-NEXT: ldrh w8, [x0]
156-
; CHECK-NEXT: ldrh w9, [x1]
155+
; CHECK-NEXT: ldrb w8, [x0, #2]
156+
; CHECK-NEXT: ldrh w9, [x0]
157+
; CHECK-NEXT: ldrb w10, [x1, #2]
158+
; CHECK-NEXT: ldrh w11, [x1]
159+
; CHECK-NEXT: orr w8, w9, w8, lsl #16
160+
; CHECK-NEXT: orr w9, w11, w10, lsl #16
157161
; CHECK-NEXT: rev w8, w8
158162
; CHECK-NEXT: rev w9, w9
159-
; CHECK-NEXT: lsr w8, w8, #16
160-
; CHECK-NEXT: lsr w9, w9, #16
161-
; CHECK-NEXT: cmp w8, w9
162-
; CHECK-NEXT: b.ne .LBB11_2
163-
; CHECK-NEXT: // %bb.1: // %loadbb1
164-
; CHECK-NEXT: ldrb w8, [x0, #2]
165-
; CHECK-NEXT: ldrb w9, [x1, #2]
166163
; CHECK-NEXT: sub w0, w8, w9
167-
; CHECK-NEXT: ret
168-
; CHECK-NEXT: .LBB11_2: // %res_block
169-
; CHECK-NEXT: mov w8, #-1 // =0xffffffff
170-
; CHECK-NEXT: cneg w0, w8, hs
171164
; CHECK-NEXT: ret
172165
%m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 3) nounwind
173166
ret i32 %m
@@ -272,20 +265,18 @@ define i1 @length4_eq_const(ptr %X) nounwind {
272265
define i32 @length5(ptr %X, ptr %Y) nounwind {
273266
; CHECK-LABEL: length5:
274267
; CHECK: // %bb.0:
275-
; CHECK-NEXT: ldr w8, [x0]
276-
; CHECK-NEXT: ldr w9, [x1]
277-
; CHECK-NEXT: rev w8, w8
278-
; CHECK-NEXT: rev w9, w9
279-
; CHECK-NEXT: cmp w8, w9
280-
; CHECK-NEXT: b.ne .LBB18_2
281-
; CHECK-NEXT: // %bb.1: // %loadbb1
282268
; CHECK-NEXT: ldrb w8, [x0, #4]
283-
; CHECK-NEXT: ldrb w9, [x1, #4]
269+
; CHECK-NEXT: ldr w9, [x0]
270+
; CHECK-NEXT: ldrb w10, [x1, #4]
271+
; CHECK-NEXT: ldr w11, [x1]
272+
; CHECK-NEXT: orr x8, x9, x8, lsl #32
273+
; CHECK-NEXT: orr x9, x11, x10, lsl #32
274+
; CHECK-NEXT: rev x8, x8
275+
; CHECK-NEXT: rev x9, x9
276+
; CHECK-NEXT: cmp x8, x9
277+
; CHECK-NEXT: cset w8, hi
278+
; CHECK-NEXT: cset w9, lo
284279
; CHECK-NEXT: sub w0, w8, w9
285-
; CHECK-NEXT: ret
286-
; CHECK-NEXT: .LBB18_2: // %res_block
287-
; CHECK-NEXT: mov w8, #-1 // =0xffffffff
288-
; CHECK-NEXT: cneg w0, w8, hs
289280
; CHECK-NEXT: ret
290281
%m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 5) nounwind
291282
ret i32 %m
@@ -310,22 +301,19 @@ define i1 @length5_eq(ptr %X, ptr %Y) nounwind {
310301
define i1 @length5_lt(ptr %X, ptr %Y) nounwind {
311302
; CHECK-LABEL: length5_lt:
312303
; CHECK: // %bb.0:
313-
; CHECK-NEXT: ldr w8, [x0]
314-
; CHECK-NEXT: ldr w9, [x1]
315-
; CHECK-NEXT: rev w8, w8
316-
; CHECK-NEXT: rev w9, w9
317-
; CHECK-NEXT: cmp w8, w9
318-
; CHECK-NEXT: b.ne .LBB20_2
319-
; CHECK-NEXT: // %bb.1: // %loadbb1
320304
; CHECK-NEXT: ldrb w8, [x0, #4]
321-
; CHECK-NEXT: ldrb w9, [x1, #4]
305+
; CHECK-NEXT: ldr w9, [x0]
306+
; CHECK-NEXT: ldrb w10, [x1, #4]
307+
; CHECK-NEXT: ldr w11, [x1]
308+
; CHECK-NEXT: orr x8, x9, x8, lsl #32
309+
; CHECK-NEXT: orr x9, x11, x10, lsl #32
310+
; CHECK-NEXT: rev x8, x8
311+
; CHECK-NEXT: rev x9, x9
312+
; CHECK-NEXT: cmp x8, x9
313+
; CHECK-NEXT: cset w8, hi
314+
; CHECK-NEXT: cset w9, lo
322315
; CHECK-NEXT: sub w8, w8, w9
323316
; CHECK-NEXT: lsr w0, w8, #31
324-
; CHECK-NEXT: ret
325-
; CHECK-NEXT: .LBB20_2: // %res_block
326-
; CHECK-NEXT: mov w8, #-1 // =0xffffffff
327-
; CHECK-NEXT: cneg w8, w8, hs
328-
; CHECK-NEXT: lsr w0, w8, #31
329317
; CHECK-NEXT: ret
330318
%m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 5) nounwind
331319
%c = icmp slt i32 %m, 0
@@ -335,28 +323,18 @@ define i1 @length5_lt(ptr %X, ptr %Y) nounwind {
335323
define i32 @length6(ptr %X, ptr %Y) nounwind {
336324
; CHECK-LABEL: length6:
337325
; CHECK: // %bb.0:
338-
; CHECK-NEXT: ldr w8, [x0]
339-
; CHECK-NEXT: ldr w9, [x1]
340-
; CHECK-NEXT: rev w8, w8
341-
; CHECK-NEXT: rev w9, w9
342-
; CHECK-NEXT: cmp w8, w9
343-
; CHECK-NEXT: b.ne .LBB21_3
344-
; CHECK-NEXT: // %bb.1: // %loadbb1
345326
; CHECK-NEXT: ldrh w8, [x0, #4]
346-
; CHECK-NEXT: ldrh w9, [x1, #4]
347-
; CHECK-NEXT: rev w8, w8
348-
; CHECK-NEXT: rev w9, w9
349-
; CHECK-NEXT: lsr w8, w8, #16
350-
; CHECK-NEXT: lsr w9, w9, #16
351-
; CHECK-NEXT: cmp w8, w9
352-
; CHECK-NEXT: b.ne .LBB21_3
353-
; CHECK-NEXT: // %bb.2:
354-
; CHECK-NEXT: mov w0, wzr
355-
; CHECK-NEXT: ret
356-
; CHECK-NEXT: .LBB21_3: // %res_block
357-
; CHECK-NEXT: cmp w8, w9
358-
; CHECK-NEXT: mov w8, #-1 // =0xffffffff
359-
; CHECK-NEXT: cneg w0, w8, hs
327+
; CHECK-NEXT: ldr w9, [x0]
328+
; CHECK-NEXT: ldrh w10, [x1, #4]
329+
; CHECK-NEXT: ldr w11, [x1]
330+
; CHECK-NEXT: orr x8, x9, x8, lsl #32
331+
; CHECK-NEXT: orr x9, x11, x10, lsl #32
332+
; CHECK-NEXT: rev x8, x8
333+
; CHECK-NEXT: rev x9, x9
334+
; CHECK-NEXT: cmp x8, x9
335+
; CHECK-NEXT: cset w8, hi
336+
; CHECK-NEXT: cset w9, lo
337+
; CHECK-NEXT: sub w0, w8, w9
360338
; CHECK-NEXT: ret
361339
%m = tail call i32 @memcmp(ptr %X, ptr %Y, i64 6) nounwind
362340
ret i32 %m

0 commit comments

Comments
 (0)