Skip to content

[CodeGen] Improve ExpandMemCmp for more efficient non-register aligned sizes handling #70469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,17 @@ class TargetTransformInfo {
// be done with two 4-byte compares instead of 4+2+1-byte compares. This
// requires all loads in LoadSizes to be doable in an unaligned way.
bool AllowOverlappingLoads = false;

// Sometimes, the amount of data that needs to be compared is smaller than
// the standard register size, but it cannot be loaded with just one load
// instruction. For example, if the size of the memory comparison is 6
// bytes, we can handle it more efficiently by loading all 6 bytes in a
// single block and generating an 8-byte number, instead of generating two
// separate blocks with conditional jumps for 4 and 2 byte loads. This
// approach simplifies the process and produces the comparison result as
  // normal. This array lists the allowed sizes of memcmp tails that can be
  // merged into one block.
SmallVector<unsigned, 4> AllowedTailExpansions;
};
MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize,
bool IsZeroCmp) const;
Expand Down
95 changes: 75 additions & 20 deletions llvm/lib/CodeGen/ExpandMemCmp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ class MemCmpExpansion {
Value *Lhs = nullptr;
Value *Rhs = nullptr;
};
LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
unsigned OffsetBytes);
LoadPair getLoadPair(Type *LoadSizeType, Type *BSwapSizeType,
Type *CmpSizeType, unsigned OffsetBytes);

static LoadEntryVector
computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
Expand All @@ -128,6 +128,11 @@ class MemCmpExpansion {
unsigned MaxNumLoads,
unsigned &NumLoadsNonOneByte);

static void optimiseLoadSequence(
LoadEntryVector &LoadSequence,
const TargetTransformInfo::MemCmpExpansionOptions &Options,
bool IsUsedForZeroCmp);

public:
MemCmpExpansion(CallInst *CI, uint64_t Size,
const TargetTransformInfo::MemCmpExpansionOptions &Options,
Expand Down Expand Up @@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
return LoadSequence;
}

void MemCmpExpansion::optimiseLoadSequence(
    LoadEntryVector &LoadSequence,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    bool IsUsedForZeroCmp) {
  // Repeatedly fuse the trailing pair of load entries into one wider load
  // whenever the combined width appears in
  // MemCmpExpansionOptions::AllowedTailExpansions. Zero-equality comparisons
  // take a different expansion path, so nothing to do for them — likewise when
  // the target declared no allowed tail widths.
  if (IsUsedForZeroCmp || Options.AllowedTailExpansions.empty())
    return;

  while (LoadSequence.size() >= 2) {
    auto Tail = LoadSequence[LoadSequence.size() - 1];
    auto Penultimate = LoadSequence[LoadSequence.size() - 2];

    // Only adjacent loads can be folded into a single wider load.
    if (Penultimate.Offset + Penultimate.LoadSize != Tail.Offset)
      return;

    // The fused width must be one the target explicitly allows.
    auto MergedSize = Tail.LoadSize + Penultimate.LoadSize;
    if (find(Options.AllowedTailExpansions, MergedSize) ==
        Options.AllowedTailExpansions.end())
      return;

    // Replace the trailing pair with the single combined entry.
    LoadSequence.pop_back();
    LoadSequence.pop_back();
    LoadSequence.emplace_back(Penultimate.Offset, MergedSize);
  }
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
Expand Down Expand Up @@ -255,6 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
}
}
assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
optimiseLoadSequence(LoadSequence, Options, IsUsedForZeroCmp);
}

unsigned MemCmpExpansion::getNumBlocks() {
Expand All @@ -278,7 +315,7 @@ void MemCmpExpansion::createResultBlock() {
}

MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
bool NeedsBSwap,
Type *BSwapSizeType,
Type *CmpSizeType,
unsigned OffsetBytes) {
// Get the memory source at offset `OffsetBytes`.
Expand Down Expand Up @@ -307,16 +344,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
if (!Rhs)
Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);

// Zero extend if Byte Swap intrinsic has different type
if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
Lhs = Builder.CreateZExt(Lhs, BSwapSizeType);
Rhs = Builder.CreateZExt(Rhs, BSwapSizeType);
}

// Swap bytes if required.
if (NeedsBSwap) {
Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
Intrinsic::bswap, LoadSizeType);
if (BSwapSizeType) {
Function *Bswap = Intrinsic::getDeclaration(
CI->getModule(), Intrinsic::bswap, BSwapSizeType);
Lhs = Builder.CreateCall(Bswap, Lhs);
Rhs = Builder.CreateCall(Bswap, Rhs);
}

// Zero extend if required.
if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
if (CmpSizeType != nullptr && CmpSizeType != Lhs->getType()) {
Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
}
Expand All @@ -332,7 +375,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
BasicBlock *BB = LoadCmpBlocks[BlockIndex];
Builder.SetInsertPoint(BB);
const LoadPair Loads =
getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
getLoadPair(Type::getInt8Ty(CI->getContext()), nullptr,
Type::getInt32Ty(CI->getContext()), OffsetBytes);
Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

Expand Down Expand Up @@ -385,11 +428,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
IntegerType *const MaxLoadType =
NumLoads == 1 ? nullptr
: IntegerType::get(CI->getContext(), MaxLoadSize * 8);

for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
const LoadPair Loads = getLoadPair(
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
/*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), nullptr,
MaxLoadType, CurLoadEntry.Offset);

if (NumLoads != 1) {
// If we have multiple loads per block, we need to generate a composite
Expand Down Expand Up @@ -475,14 +519,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {

Type *LoadSizeType =
IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
Type *BSwapSizeType =
DL.isLittleEndian()
? IntegerType::get(CI->getContext(),
PowerOf2Ceil(CurLoadEntry.LoadSize * 8))
: nullptr;
Type *MaxLoadType = IntegerType::get(
CI->getContext(),
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(CurLoadEntry.LoadSize)) * 8);
assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

const LoadPair Loads =
getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
CurLoadEntry.Offset);
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
CurLoadEntry.Offset);

// Add the loaded values to the phi nodes for calculating memcmp result only
// if result is not used in a zero equality.
Expand Down Expand Up @@ -587,19 +637,24 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
bool NeedsBSwap = DL.isLittleEndian() && Size != 1;
Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
Type *BSwapSizeType =
NeedsBSwap ? IntegerType::get(CI->getContext(), PowerOf2Ceil(Size * 8))
: nullptr;
Type *MaxLoadType =
IntegerType::get(CI->getContext(),
std::max(MaxLoadSize, (unsigned)PowerOf2Ceil(Size)) * 8);

// The i8 and i16 cases don't need compares. We zext the loaded values and
// subtract them to get the suitable negative, zero, or positive i32 result.
if (Size < 4) {
const LoadPair Loads =
getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
/*Offset*/ 0);
if (Size == 1 || Size == 2) {
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType,
Builder.getInt32Ty(), /*Offset*/ 0);
return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
}

const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
const LoadPair Loads = getLoadPair(LoadSizeType, BSwapSizeType, MaxLoadType,
/*Offset*/ 0);
// The result of memcmp is negative, zero, or positive, so produce that by
// subtracting 2 extended compare bits: sub (ugt, ult).
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2994,6 +2994,7 @@ AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
// they may wake up the FP unit, which raises the power consumption. Perhaps
// they could be used with no holds barred (-O3).
Options.LoadSizes = {8, 4, 2, 1};
Options.AllowedTailExpansions = {3, 5, 6};
return Options;
}

Expand Down
Loading