Skip to content

Commit 7686677

Browse files
committed
Address the review comments
1 parent c2936b6 commit 7686677

File tree

1 file changed

+58
-47
lines changed

1 file changed

+58
-47
lines changed

llvm/lib/CodeGen/ExpandMemCmp.cpp

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class MemCmpExpansion {
117117
Value *Lhs = nullptr;
118118
Value *Rhs = nullptr;
119119
};
120-
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *BSwapSizeType,
120+
// Emits the pair of loads for one comparison step at byte offset OffsetBytes.
// BSwapSizeType doubles as the "needs byte swap" flag: pass nullptr when no
// byte swap is wanted. When it is non-null, the definition zero-extends the
// loaded values to BSwapSizeType if that differs from LoadSizeType and then
// applies the bswap intrinsic of that type.
// NOTE(review): how CmpSizeType is used is not visible in this diff; confirm
// against the full definition.
LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
121121
Type *CmpSizeType, unsigned OffsetBytes);
122122

123123
static LoadEntryVector
@@ -128,6 +128,11 @@ class MemCmpExpansion {
128128
unsigned MaxNumLoads,
129129
unsigned &NumLoadsNonOneByte);
130130

131+
// Merges trailing entries of LoadSequence in place when they are contiguous
// and their combined size is listed in Options.AllowedTailExpansions. It is a
// no-op when IsUsedForZeroCmp is true or when AllowedTailExpansions is empty.
static void optimiseLoadSequence(
132+
LoadEntryVector &LoadSequence,
133+
const TargetTransformInfo::MemCmpExpansionOptions &Options,
134+
bool IsUsedForZeroCmp);
135+
131136
public:
132137
MemCmpExpansion(CallInst *CI, uint64_t Size,
133138
const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
210215
return LoadSequence;
211216
}
212217

218+
/// Greedily merges the trailing entries of \p LoadSequence in place.
///
/// While the sequence holds at least two entries, the last two are examined:
/// if the second-to-last entry ends exactly where the last one begins and the
/// sum of their sizes appears in `Options.AllowedTailExpansions`, both are
/// replaced by one entry that starts at the earlier offset and covers both.
/// Merging stops at the first pair that fails either check.
///
/// \param LoadSequence     Sequence of (offset, size) load entries to rewrite.
/// \param Options          Supplies the list of allowed merged load sizes.
/// \param IsUsedForZeroCmp When true, the sequence is left untouched.
void MemCmpExpansion::optimiseLoadSequence(
219+
LoadEntryVector &LoadSequence,
220+
const TargetTransformInfo::MemCmpExpansionOptions &Options,
221+
bool IsUsedForZeroCmp) {
222+
// Tail merging only applies when the result is not used for a zero
223+
// comparison and when at least one merged size is listed in
224+
// `MemCmpExpansionOptions::AllowedTailExpansions`. In every other case the
225+
// sequence is returned unchanged.
226+
if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty())
227+
return;
228+
229+
while (LoadSequence.size() >= 2) {
230+
// Taken by value on purpose: both entries are popped below, and
// PreLast.Offset is still needed after that.
auto Last = LoadSequence[LoadSequence.size() - 1];
231+
auto PreLast = LoadSequence[LoadSequence.size() - 2];
232+
233+
// Stop unless the second-to-last entry ends exactly where the last begins.
234+
if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
235+
break;
236+
237+
// Stop unless the combined size is one of the allowed tail expansions.
auto LoadSize = Last.LoadSize + PreLast.LoadSize;
238+
if (find(Options.AllowedTailExpansions, LoadSize) ==
239+
Options.AllowedTailExpansions.end())
240+
break;
241+
242+
// Replace the last two entries with a single entry spanning both.
243+
LoadSequence.pop_back();
244+
LoadSequence.pop_back();
245+
LoadSequence.emplace_back(PreLast.Offset, LoadSize);
246+
}
247+
}
248+
213249
// Initialize the basic block structure required for expansion of memcmp call
214250
// with given maximum load size and memcmp size parameter.
215251
// This structure includes:
@@ -255,31 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
255291
}
256292
}
257293
assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
258-
// This part of code attempts to optimize the LoadSequence by merging allowed
259-
// subsequences into single loads of allowed sizes from
260-
// `AllowedTailExpansions`. If it is for zero comparison or if no allowed tail
261-
// expansions are specified, we exit early.
262-
if (IsUsedForZeroCmp || !Options.AllowedTailExpansions.size())
263-
return;
264-
265-
while (LoadSequence.size() >= 2) {
266-
auto Last = LoadSequence[LoadSequence.size() - 1];
267-
auto PreLast = LoadSequence[LoadSequence.size() - 2];
268-
269-
// Exit the loop if the two sequences are not contiguous
270-
if (PreLast.Offset + PreLast.LoadSize != Last.Offset)
271-
break;
272-
273-
auto LoadSize = Last.LoadSize + PreLast.LoadSize;
274-
if (find(Options.AllowedTailExpansions, LoadSize) ==
275-
Options.AllowedTailExpansions.end())
276-
break;
277-
278-
// Remove the last two sequences and replace with the combined sequence
279-
LoadSequence.pop_back();
280-
LoadSequence.pop_back();
281-
LoadSequence.emplace_back(PreLast.Offset, LoadSize);
282-
}
294+
optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
283295
}
284296

285297
unsigned MemCmpExpansion::getNumBlocks() {
@@ -303,7 +315,6 @@ void MemCmpExpansion::createResultBlock() {
303315
}
304316

305317
MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
306-
bool NeedsBSwap,
307318
Type *BSwapSizeType,
308319
Type *CmpSizeType,
309320
unsigned OffsetBytes) {
@@ -334,13 +345,13 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
334345
Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);
335346

336347
// Zero extend if Byte Swap intrinsic has different type
337-
if (NeedsBSwap && LoadSizeType != BSwapSizeType) {
348+
if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
338349
Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
339350
Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
340351
}
341352

342353
// Swap bytes if required.
343-
if (NeedsBSwap) {
354+
if (BSwapSizeType) {
344355
Function *Bswap = Intrinsic::getDeclaration(
345356
CI->getModule(), Intrinsic::bswap, BSwapSizeType);
346357
Lhs = Builder.CreateCall(Bswap, Lhs);
@@ -364,8 +375,8 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
364375
BasicBlock *BB = LoadCmpBlocks[BlockIndex];
365376
Builder.SetInsertPoint(BB);
366377
const LoadPair Loads =
367-
getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
368-
nullptr, Type::getInt32Ty(CI->getContext()), OffsetBytes);
378+
getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
379+
Type::getInt32Ty(CI->getContext()), OffsetBytes);
369380
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);
370381

371382
PhiRes->addIncoming(Diff, BB);
@@ -421,8 +432,8 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
421432
for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
422433
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
423434
const LoadPair Loads = getLoadPair(
424-
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
425-
/*NeedsBSwap=*/false, nullptr, MaxLoadType, CurLoadEntry.Offset);
435+
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
436+
MaxLoadType, CurLoadEntry.Offset);
426437

427438
if (NumLoads != 1) {
428439
// If we have multiple loads per block, we need to generate a composite
@@ -508,18 +519,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
508519

509520
Type *LoadSizeType =
510521
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
511-
Type *BSwapSizeType = IntegerType::get(
512-
CI->getContext(), PowerOf2Ceil(CurLoadEntry.LoadSize * 8));
522+
Type *BSwapSizeType =
523+
DL.isLittleEndian()
524+
? IntegerType::get(CI->getContext(),
525+
PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
526+
: nullptr;
513527
Type *MaxLoadType = IntegerType::get(
514528
CI->getContext(),
515529
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
516530
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
517531

518532
Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
519533

520-
const LoadPair Loads =
521-
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(),
522-
BSwapSizeType, MaxLoadType, CurLoadEntry.Offset);
534+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
535+
CurLoadEntry.Offset);
523536

524537
// Add the loaded values to the phi nodes for calculating memcmp result only
525538
// if result is not used in a zero equality.
@@ -624,27 +637,25 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
624637
/// A memcmp expansion that only has one block of load and compare can bypass
625638
/// the compare, branch, and phi IR that is required in the general case.
626639
Value *MemCmpExpansion::getMemCmpOneBlock() {
640+
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
627641
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
628642
Type *BSwapSizeType =
629-
IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8));
643+
NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
644+
: nullptr;
630645
Type *MaxLoadType =
631646
IntegerType::get(CI->getContext(),
632647
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);
633648

634-
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
635-
636649
// The i8 and i16 cases don't need compares. We zext the loaded values and
637650
// subtract them to get the suitable negative, zero, or positive i32 result.
638651
if (Size < 4) {
639-
const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType,
640-
Builder.getInt32Ty(),
641-
/*Offset*/ 0);
652+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
653+
Builder.getInt32Ty(), /*Offset*/ 0);
642654
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
643655
}
644656

645-
const LoadPair Loads =
646-
getLoadPair(LoadSizeType, NeedsBSwap, BSwapSizeType, MaxLoadType,
647-
/*Offset*/ 0);
657+
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
658+
/*Offset*/ 0);
648659
// The result of memcmp is negative, zero, or positive, so produce that by
649660
// subtracting 2 extended compare bits: sub (ugt, ult).
650661
// If a target prefers to use selects to get -1/0/1, they should be able

0 commit comments

Comments
 (0)