Skip to content

[AArch64] Don't tail call memset if it would convert to a bzero. #98969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions llvm/include/llvm/CodeGen/Analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
/// between it and the return.
///
/// This function only tests target-independent requirements.
bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM);
bool isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
bool ReturnsFirstArg = false);

/// Test if given that the input instruction is in the tail call position, if
/// there is an attribute mismatch between the caller and the callee that will
Expand All @@ -144,7 +145,12 @@ bool attributesPermitTailCall(const Function *F, const Instruction *I,
/// optimization.
bool returnTypeIsEligibleForTailCall(const Function *F, const Instruction *I,
const ReturnInst *Ret,
const TargetLoweringBase &TLI);
const TargetLoweringBase &TLI,
bool ReturnsFirstArg = false);

/// Returns true if the parent of \p CI returns CI's first argument after
/// calling \p CI.
bool funcReturnsFirstArgOfCall(const CallInst &CI);

DenseMap<const MachineBasicBlock *, int>
getEHScopeMembership(const MachineFunction &MF);
Expand Down
26 changes: 16 additions & 10 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1182,24 +1182,30 @@ class SelectionDAG {
/// stack arguments from being clobbered.
SDValue getStackArgumentTokenFactor(SDValue Chain);

SDValue getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVol,
bool AlwaysInline, bool isTailCall,
MachinePointerInfo DstPtrInfo,
MachinePointerInfo SrcPtrInfo,
const AAMDNodes &AAInfo = AAMDNodes(),
AAResults *AA = nullptr);

/* \p CI if not null is the memset call being lowered.
* \p OverrideTailCall is an optional parameter that can be used to override
* the tail call optimization decision. */
SDValue
getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVol, bool AlwaysInline,
const CallInst *CI, std::optional<bool> OverrideTailCall,
MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo,
const AAMDNodes &AAInfo = AAMDNodes(), AAResults *AA = nullptr);

/* \p CI if not null is the memset call being lowered.
* \p OverrideTailCall is an optional parameter that can be used to override
* the tail call optimization decision. */
SDValue getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVol, bool isTailCall,
SDValue Size, Align Alignment, bool isVol,
const CallInst *CI, std::optional<bool> OverrideTailCall,
MachinePointerInfo DstPtrInfo,
MachinePointerInfo SrcPtrInfo,
const AAMDNodes &AAInfo = AAMDNodes(),
AAResults *AA = nullptr);

SDValue getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src,
SDValue Size, Align Alignment, bool isVol,
bool AlwaysInline, bool isTailCall,
bool AlwaysInline, const CallInst *CI,
MachinePointerInfo DstPtrInfo,
const AAMDNodes &AAInfo = AAMDNodes());

Expand Down
57 changes: 19 additions & 38 deletions llvm/lib/CodeGen/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ static bool nextRealType(SmallVectorImpl<Type *> &SubTypes,
/// between it and the return.
///
/// This function only tests target-independent requirements.
bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {
bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM,
bool ReturnsFirstArg) {
const BasicBlock *ExitBB = Call.getParent();
const Instruction *Term = ExitBB->getTerminator();
const ReturnInst *Ret = dyn_cast<ReturnInst>(Term);
Expand Down Expand Up @@ -575,7 +576,8 @@ bool llvm::isInTailCallPosition(const CallBase &Call, const TargetMachine &TM) {

const Function *F = ExitBB->getParent();
return returnTypeIsEligibleForTailCall(
F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering());
F, &Call, Ret, *TM.getSubtargetImpl(*F)->getTargetLowering(),
ReturnsFirstArg);
}

bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
Expand Down Expand Up @@ -638,26 +640,11 @@ bool llvm::attributesPermitTailCall(const Function *F, const Instruction *I,
return CallerAttrs == CalleeAttrs;
}

/// Check whether B is a bitcast of a pointer type to another pointer type,
/// which is equal to A.
static bool isPointerBitcastEqualTo(const Value *A, const Value *B) {
assert(A && B && "Expected non-null inputs!");

auto *BitCastIn = dyn_cast<BitCastInst>(B);

if (!BitCastIn)
return false;

if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
return false;

return A == BitCastIn->getOperand(0);
}

bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
const Instruction *I,
const ReturnInst *Ret,
const TargetLoweringBase &TLI) {
const TargetLoweringBase &TLI,
bool ReturnsFirstArg) {
// If the block ends with a void return or unreachable, it doesn't matter
// what the call's return type is.
if (!Ret || Ret->getNumOperands() == 0) return true;
Expand All @@ -671,26 +658,11 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
if (!attributesPermitTailCall(F, I, Ret, TLI, &AllowDifferingSizes))
return false;

const Value *RetVal = Ret->getOperand(0), *CallVal = I;
// Intrinsic like llvm.memcpy has no return value, but the expanded
// libcall may or may not have return value. On most platforms, it
// will be expanded as memcpy in libc, which returns the first
// argument. On other platforms like arm-none-eabi, memcpy may be
// expanded as library call without return value, like __aeabi_memcpy.
const CallInst *Call = cast<CallInst>(I);
if (Function *F = Call->getCalledFunction()) {
Intrinsic::ID IID = F->getIntrinsicID();
if (((IID == Intrinsic::memcpy &&
TLI.getLibcallName(RTLIB::MEMCPY) == StringRef("memcpy")) ||
(IID == Intrinsic::memmove &&
TLI.getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove")) ||
(IID == Intrinsic::memset &&
TLI.getLibcallName(RTLIB::MEMSET) == StringRef("memset"))) &&
(RetVal == Call->getArgOperand(0) ||
isPointerBitcastEqualTo(RetVal, Call->getArgOperand(0))))
return true;
}
// If the return value is the first argument of the call.
if (ReturnsFirstArg)
return true;

const Value *RetVal = Ret->getOperand(0), *CallVal = I;
SmallVector<unsigned, 4> RetPath, CallPath;
SmallVector<Type *, 4> RetSubTypes, CallSubTypes;

Expand Down Expand Up @@ -739,6 +711,15 @@ bool llvm::returnTypeIsEligibleForTailCall(const Function *F,
return true;
}

bool llvm::funcReturnsFirstArgOfCall(const CallInst &CI) {
const ReturnInst *Ret = dyn_cast<ReturnInst>(CI.getParent()->getTerminator());
Value *RetVal = Ret ? Ret->getReturnValue() : nullptr;
bool ReturnsFirstArg = false;
if (RetVal && ((RetVal == CI.getArgOperand(0))))
ReturnsFirstArg = true;
return ReturnsFirstArg;
}

static void collectEHScopeMembers(
DenseMap<const MachineBasicBlock *, int> &EHScopeMembership, int EHScope,
const MachineBasicBlock *MBB) {
Expand Down
62 changes: 49 additions & 13 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <optional>
#include <set>
#include <string>
#include <utility>
Expand Down Expand Up @@ -8236,12 +8237,11 @@ static void checkAddrSpaceIsValidForLibcall(const TargetLowering *TLI,
}
}

SDValue SelectionDAG::getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst,
SDValue Src, SDValue Size, Align Alignment,
bool isVol, bool AlwaysInline, bool isTailCall,
MachinePointerInfo DstPtrInfo,
MachinePointerInfo SrcPtrInfo,
const AAMDNodes &AAInfo, AAResults *AA) {
SDValue SelectionDAG::getMemcpy(
SDValue Chain, const SDLoc &dl, SDValue Dst, SDValue Src, SDValue Size,
Align Alignment, bool isVol, bool AlwaysInline, const CallInst *CI,
std::optional<bool> OverrideTailCall, MachinePointerInfo DstPtrInfo,
MachinePointerInfo SrcPtrInfo, const AAMDNodes &AAInfo, AAResults *AA) {
// Check to see if we should lower the memcpy to loads and stores first.
// For cases within the target-specified limits, this is the best choice.
ConstantSDNode *ConstantSize = dyn_cast<ConstantSDNode>(Size);
Expand Down Expand Up @@ -8296,6 +8296,18 @@ SDValue SelectionDAG::getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst,
Entry.Node = Size; Args.push_back(Entry);
// FIXME: pass in SDLoc
TargetLowering::CallLoweringInfo CLI(*this);
bool IsTailCall = false;
if (OverrideTailCall.has_value()) {
IsTailCall = *OverrideTailCall;
} else {
bool LowersToMemcpy =
TLI->getLibcallName(RTLIB::MEMCPY) == StringRef("memcpy");
bool ReturnsFirstArg = CI && funcReturnsFirstArgOfCall(*CI);
IsTailCall = CI && CI->isTailCall() &&
isInTailCallPosition(*CI, getTarget(),
ReturnsFirstArg && LowersToMemcpy);
}

CLI.setDebugLoc(dl)
.setChain(Chain)
.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMCPY),
Expand All @@ -8304,7 +8316,7 @@ SDValue SelectionDAG::getMemcpy(SDValue Chain, const SDLoc &dl, SDValue Dst,
TLI->getPointerTy(getDataLayout())),
std::move(Args))
.setDiscardResult()
.setTailCall(isTailCall);
.setTailCall(IsTailCall);

std::pair<SDValue,SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
Expand Down Expand Up @@ -8352,7 +8364,8 @@ SDValue SelectionDAG::getAtomicMemcpy(SDValue Chain, const SDLoc &dl,

SDValue SelectionDAG::getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst,
SDValue Src, SDValue Size, Align Alignment,
bool isVol, bool isTailCall,
bool isVol, const CallInst *CI,
std::optional<bool> OverrideTailCall,
MachinePointerInfo DstPtrInfo,
MachinePointerInfo SrcPtrInfo,
const AAMDNodes &AAInfo, AAResults *AA) {
Expand Down Expand Up @@ -8398,6 +8411,19 @@ SDValue SelectionDAG::getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst,
Entry.Node = Size; Args.push_back(Entry);
// FIXME: pass in SDLoc
TargetLowering::CallLoweringInfo CLI(*this);

bool IsTailCall = false;
if (OverrideTailCall.has_value()) {
IsTailCall = *OverrideTailCall;
} else {
bool LowersToMemmove =
TLI->getLibcallName(RTLIB::MEMMOVE) == StringRef("memmove");
bool ReturnsFirstArg = CI && funcReturnsFirstArgOfCall(*CI);
IsTailCall = CI && CI->isTailCall() &&
isInTailCallPosition(*CI, getTarget(),
ReturnsFirstArg && LowersToMemmove);
}

CLI.setDebugLoc(dl)
.setChain(Chain)
.setLibCallee(TLI->getLibcallCallingConv(RTLIB::MEMMOVE),
Expand All @@ -8406,7 +8432,7 @@ SDValue SelectionDAG::getMemmove(SDValue Chain, const SDLoc &dl, SDValue Dst,
TLI->getPointerTy(getDataLayout())),
std::move(Args))
.setDiscardResult()
.setTailCall(isTailCall);
.setTailCall(IsTailCall);

std::pair<SDValue,SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
Expand Down Expand Up @@ -8454,7 +8480,8 @@ SDValue SelectionDAG::getAtomicMemmove(SDValue Chain, const SDLoc &dl,

SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
SDValue Src, SDValue Size, Align Alignment,
bool isVol, bool AlwaysInline, bool isTailCall,
bool isVol, bool AlwaysInline,
const CallInst *CI,
MachinePointerInfo DstPtrInfo,
const AAMDNodes &AAInfo) {
// Check to see if we should lower the memset to stores first.
Expand Down Expand Up @@ -8514,8 +8541,9 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
return Entry;
};

bool UseBZero = isNullConstant(Src) && BzeroName;
// If zeroing out and bzero is present, use it.
if (isNullConstant(Src) && BzeroName) {
if (UseBZero) {
TargetLowering::ArgListTy Args;
Args.push_back(CreateEntry(Dst, PointerType::getUnqual(Ctx)));
Args.push_back(CreateEntry(Size, DL.getIntPtrType(Ctx)));
Expand All @@ -8533,8 +8561,16 @@ SDValue SelectionDAG::getMemset(SDValue Chain, const SDLoc &dl, SDValue Dst,
TLI->getPointerTy(DL)),
std::move(Args));
}

CLI.setDiscardResult().setTailCall(isTailCall);
bool LowersToMemset =
TLI->getLibcallName(RTLIB::MEMSET) == StringRef("memset");
// If we're going to use bzero, make sure not to tail call unless the
// subsequent return doesn't need a value, as bzero doesn't return the first
// arg unlike memset.
bool ReturnsFirstArg = CI && funcReturnsFirstArgOfCall(*CI) && !UseBZero;
bool IsTailCall =
CI && CI->isTailCall() &&
isInTailCallPosition(*CI, getTarget(), ReturnsFirstArg && LowersToMemset);
CLI.setDiscardResult().setTailCall(IsTailCall);

std::pair<SDValue, SDValue> CallResult = TLI->LowerCallTo(CLI);
return CallResult.second;
Expand Down
41 changes: 19 additions & 22 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6460,14 +6460,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Align SrcAlign = MCI.getSourceAlign().valueOrOne();
Align Alignment = std::min(DstAlign, SrcAlign);
bool isVol = MCI.isVolatile();
bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
// FIXME: Support passing different dest/src alignments to the memcpy DAG
// node.
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MC = DAG.getMemcpy(
Root, sdl, Op1, Op2, Op3, Alignment, isVol,
/* AlwaysInline */ false, isTC, MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)), I.getAAMetadata(), AA);
SDValue MC = DAG.getMemcpy(Root, sdl, Op1, Op2, Op3, Alignment, isVol,
/* AlwaysInline */ false, &I, std::nullopt,
MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)),
I.getAAMetadata(), AA);
updateDAGForMaybeTailCall(MC);
return;
}
Expand All @@ -6482,13 +6482,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Align SrcAlign = MCI.getSourceAlign().valueOrOne();
Align Alignment = std::min(DstAlign, SrcAlign);
bool isVol = MCI.isVolatile();
bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
// FIXME: Support passing different dest/src alignments to the memcpy DAG
// node.
SDValue MC = DAG.getMemcpy(
getRoot(), sdl, Dst, Src, Size, Alignment, isVol,
/* AlwaysInline */ true, isTC, MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)), I.getAAMetadata(), AA);
SDValue MC = DAG.getMemcpy(getRoot(), sdl, Dst, Src, Size, Alignment, isVol,
/* AlwaysInline */ true, &I, std::nullopt,
MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)),
I.getAAMetadata(), AA);
updateDAGForMaybeTailCall(MC);
return;
}
Expand All @@ -6500,11 +6500,10 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
// @llvm.memset defines 0 and 1 to both mean no alignment.
Align Alignment = MSI.getDestAlign().valueOrOne();
bool isVol = MSI.isVolatile();
bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MS = DAG.getMemset(
Root, sdl, Op1, Op2, Op3, Alignment, isVol, /* AlwaysInline */ false,
isTC, MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
&I, MachinePointerInfo(I.getArgOperand(0)), I.getAAMetadata());
updateDAGForMaybeTailCall(MS);
return;
}
Expand All @@ -6517,10 +6516,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
// @llvm.memset defines 0 and 1 to both mean no alignment.
Align DstAlign = MSII.getDestAlign().valueOrOne();
bool isVol = MSII.isVolatile();
bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MC = DAG.getMemset(Root, sdl, Dst, Value, Size, DstAlign, isVol,
/* AlwaysInline */ true, isTC,
/* AlwaysInline */ true, &I,
MachinePointerInfo(I.getArgOperand(0)),
I.getAAMetadata());
updateDAGForMaybeTailCall(MC);
Expand All @@ -6536,12 +6534,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Align SrcAlign = MMI.getSourceAlign().valueOrOne();
Align Alignment = std::min(DstAlign, SrcAlign);
bool isVol = MMI.isVolatile();
bool isTC = I.isTailCall() && isInTailCallPosition(I, DAG.getTarget());
// FIXME: Support passing different dest/src alignments to the memmove DAG
// node.
SDValue Root = isVol ? getRoot() : getMemoryRoot();
SDValue MM = DAG.getMemmove(Root, sdl, Op1, Op2, Op3, Alignment, isVol,
isTC, MachinePointerInfo(I.getArgOperand(0)),
SDValue MM = DAG.getMemmove(Root, sdl, Op1, Op2, Op3, Alignment, isVol, &I,
/* OverrideTailCall */ std::nullopt,
MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)),
I.getAAMetadata(), AA);
updateDAGForMaybeTailCall(MM);
Expand Down Expand Up @@ -9039,11 +9037,10 @@ bool SelectionDAGBuilder::visitMemPCpyCall(const CallInst &I) {
// because the return pointer needs to be adjusted by the size of
// the copied memory.
SDValue Root = getMemoryRoot();
SDValue MC = DAG.getMemcpy(Root, sdl, Dst, Src, Size, Alignment, false, false,
/*isTailCall=*/false,
MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)),
I.getAAMetadata());
SDValue MC = DAG.getMemcpy(
Root, sdl, Dst, Src, Size, Alignment, false, false, /*CI=*/nullptr,
std::nullopt, MachinePointerInfo(I.getArgOperand(0)),
MachinePointerInfo(I.getArgOperand(1)), I.getAAMetadata());
assert(MC.getNode() != nullptr &&
"** memcpy should not be lowered as TailCall in mempcpy context **");
DAG.setRoot(MC);
Expand Down
Loading
Loading