Skip to content

Commit 849f963

Browse files
authored
[CodeGen] Improve ExpandMemCmp for more efficient non-register aligned sizes handling (#70469)
* Enhance the logic of the ExpandMemCmp pass to merge contiguous subsequences in LoadSequence, based on the sizes allowed in `AllowedTailExpansions`.
* This enhancement seeks to minimize the number of basic blocks and produce optimized code when memcmp is used with non-register-aligned sizes.
* Enable this feature for AArch64 for memcmp sizes whose remainder modulo 8 is 3, 5, or 6.

Reapplication of #69942 after fixing a bug.
1 parent 89564f0 commit 849f963

File tree

5 files changed

+3973
-20
lines changed

5 files changed

+3973
-20
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,17 @@ class TargetTransformInfo {
907907
// be done with two 4-byte compares instead of 4+2+1-byte compares. This
908908
// requires all loads in LoadSizes to be doable in an unaligned way.
909909
bool AllowOverlappingLoads = false;
910+
911+
// Sometimes, the amount of data that needs to be compared is smaller than
912+
// the standard register size, but it cannot be loaded with just one load
913+
// instruction. For example, if the size of the memory comparison is 6
914+
// bytes, we can handle it more efficiently by loading all 6 bytes in a
915+
// single block and zero-extending them to an 8-byte value, instead of generating two
916+
// separate blocks with conditional jumps for 4 and 2 byte loads. This
917+
// approach simplifies the process and produces the comparison result as
918+
// normal. This array lists the allowed sizes of memcmp tails that can be
919+
// merged into one block.
920+
SmallVector<unsigned, 4> AllowedTailExpansions;
910921
};
911922
MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
912923
bool IsZeroCmp) const;

llvm/lib/CodeGen/ExpandMemCmp.cpp

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class MemCmpExpansion {
117117
Value *Lhs = nullptr;
118118
Value *Rhs = nullptr;
119119
};
120-
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
121-
unsigned OffsetBytes);
120+
LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
121+
Type *CmpSizeType, unsigned OffsetBytes);
122122

123123
static LoadEntryVector
124124
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
@@ -128,6 +128,11 @@ class MemCmpExpansion {
128128
unsigned MaxNumLoads,
129129
unsigned &NumLoadsNonOneByte);
130130

131+
static void optimiseLoadSequence(
132+
LoadEntryVector &LoadSequence,
133+
const TargetTransformInfo::MemCmpExpansionOptions &Options,
134+
bool IsUsedForZeroCmp);
135+
131136
public:
132137
MemCmpExpansion(CallInst *CI, uint64_t Size,
133138
const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
210215
return LoadSequence;
211216
}
212217

218+
// Merges trailing entries of `LoadSequence` in place: as long as the last two
// entries are contiguous and their combined size is listed in
// `Options.AllowedTailExpansions`, they are replaced by one combined entry.
//
// LoadSequence      - the sequence of loads to rewrite; modified in place.
// Options           - supplies `AllowedTailExpansions`, the list of combined
//                     sizes that may be merged.
// IsUsedForZeroCmp  - when true, the sequence is left untouched.
void MemCmpExpansion::optimiseLoadSequence(
219+
LoadEntryVector &LoadSequence,
220+
const TargetTransformInfo::MemCmpExpansionOptions &Options,
221+
bool IsUsedForZeroCmp) {
222+
// This part of code attempts to optimize the LoadSequence by merging allowed
223+
// subsequences into single loads of allowed sizes from
224+
// `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero
225+
// comparison or if no allowed tail expansions are specified, we exit early.
226+
if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty())
227+
return;
228+
229+
// Each iteration merges the last two entries, so the merged entry can itself
// be merged with its predecessor on the next iteration.
while (LoadSequence.size() >= 2) {
230+
// Both entries are copied by value because they are popped from the vector
// below.
auto Last = LoadSequence[LoadSequence.size() - 1];
231+
auto PreLast = LoadSequence[LoadSequence.size() - 2];
232+
233+
// Exit the loop if the two sequences are not contiguous
234+
if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
235+
break;
236+
237+
// Stop as soon as the combined size is not one of the allowed tail sizes.
auto LoadSize = Last.LoadSize + PreLast.LoadSize;
238+
if (find(Options.AllowedTailExpansions, LoadSize) ==
239+
Options.AllowedTailExpansions.end())
240+
break;
241+
242+
// Remove the last two sequences and replace with the combined sequence
243+
LoadSequence.pop_back();
244+
LoadSequence.pop_back();
245+
// NOTE(review): the arguments are forwarded to the LoadEntry constructor in
// the order (PreLast.Offset, LoadSize). The constructor is not visible here;
// confirm it takes (offset, size) and not (size, offset), otherwise the two
// fields of the merged entry are swapped.
LoadSequence.emplace_back(PreLast.Offset, LoadSize);
246+
}
247+
}
248+
213249
// Initialize the basic block structure required for expansion of memcmp call
214250
// with given maximum load size and memcmp size parameter.
215251
// This structure includes:
@@ -255,6 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
255291
}
256292
}
257293
assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
294+
optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
258295
}
259296

260297
unsigned MemCmpExpansion::getNumBlocks() {
@@ -278,7 +315,7 @@ void MemCmpExpansion::createResultBlock() {
278315
}
279316

280317
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
281-
bool NeedsBSwap,
318+
Type *BSwapSizeType,
282319
Type *CmpSizeType,
283320
unsigned OffsetBytes) {
284321
// Get the memory source at offset `OffsetBytes`.
@@ -307,16 +344,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
307344
if (!Rhs)
308345
Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
309346

347+
// Zero-extend the loaded values if the byte-swap type differs from the load type.
348+
if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
349+
Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
350+
Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
351+
}
352+
310353
// Swap bytes if required.
311-
if (NeedsBSwap) {
312-
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
313-
Intrinsic::bswap, LoadSizeType);
354+
if (BSwapSizeType) {
355+
Function *Bswap = Intrinsic::getDeclaration(
356+
CI->getModule(), Intrinsic::bswap, BSwapSizeType);
314357
Lhs = Builder.CreateCall(Bswap, Lhs);
315358
Rhs = Builder.CreateCall(Bswap, Rhs);
316359
}
317360

318361
// Zero extend if required.
319-
if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
362+
if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
320363
Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
321364
Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
322365
}
@@ -332,7 +375,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
332375
BasicBlock *BB = LoadCmpBlocks[BlockIndex];
333376
Builder.SetInsertPoint(BB);
334377
const LoadPair Loads =
335-
getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
378+
getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
336379
Type::getInt32Ty(CI->getContext()), OffsetBytes);
337380
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
338381

@@ -385,11 +428,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
385428
IntegerType *const MaxLoadType =
386429
NumLoads == 1 ? nullptr
387430
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);
431+
388432
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
389433
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
390434
const LoadPair Loads = getLoadPair(
391-
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
392-
/*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
435+
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
436+
MaxLoadType, CurLoadEntry.Offset);
393437

394438
if (NumLoads != 1) {
395439
// If we have multiple loads per block, we need to generate a composite
@@ -475,14 +519,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
475519

476520
Type *LoadSizeType =
477521
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
478-
Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
522+
Type *BSwapSizeType =
523+
DL.isLittleEndian()
524+
? IntegerType::get(CI->getContext(),
525+
PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
526+
: nullptr;
527+
Type *MaxLoadType = IntegerType::get(
528+
CI->getContext(),
529+
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
479530
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
480531

481532
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
482533

483-
const LoadPair Loads =
484-
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
485-
CurLoadEntry.Offset);
534+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
535+
CurLoadEntry.Offset);
486536

487537
// Add the loaded values to the phi nodes for calculating memcmp result only
488538
// if result is not used in a zero equality.
@@ -587,19 +637,24 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
587637
/// A memcmp expansion that only has one block of load and compare can bypass
588638
/// the compare, branch, and phi IR that is required in the general case.
589639
Value *MemCmpExpansion::getMemCmpOneBlock() {
590-
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
591640
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
641+
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
642+
Type *BSwapSizeType =
643+
NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
644+
: nullptr;
645+
Type *MaxLoadType =
646+
IntegerType::get(CI->getContext(),
647+
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);
592648

593649
// The i8 and i16 cases don't need compares. We zext the loaded values and
594650
// subtract them to get the suitable negative, zero, or positive i32 result.
595-
if (Size < 4) {
596-
const LoadPair Loads =
597-
getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
598-
/*Offset*/ 0);
651+
if (Size == 1 || Size == 2) {
652+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
653+
Builder.getInt32Ty(), /*Offset*/ 0);
599654
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
600655
}
601656

602-
const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
657+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
603658
/*Offset*/ 0);
604659
// The result of memcmp is negative, zero, or positive, so produce that by
605660
// subtracting 2 extended compare bits: sub (ugt, ult).

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,6 +2994,7 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
29942994
// they may wake up the FP unit, which raises the power consumption. Perhaps
29952995
// they could be used with no holds barred (-O3).
29962996
Options.LoadSizes = {8, 4, 2, 1};
2997+
Options.AllowedTailExpansions = {3, 5, 6};
29972998
return Options;
29982999
}
29993000

0 commit comments

Comments
 (0)