[X86][CodeGen] Support hoisting load/store with conditional faulting #96720

Merged: 6 commits, Jun 27, 2024
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
/// \return the number of registers in the target-provided register class.
unsigned getNumberOfRegisters(unsigned ClassID) const;

/// \return true if the target supports load/store that enables fault
/// suppression of memory operands when the source condition is false.
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;

/// \return the target-provided register class ID for the provided type,
/// accounting for type promotion and other type-legalization techniques that
/// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
const Function &Fn) const = 0;
virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
virtual unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const = 0;
virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
unsigned getNumberOfRegisters(unsigned ClassID) const override {
return Impl.getNumberOfRegisters(ClassID);
}
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
return Impl.hasConditionalLoadStoreForType(Ty);
}
unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const override {
return Impl.getRegisterClassForType(Vector, Ty);
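For context (not part of the diff): a minimal sketch of how a middle-end transform could consult the new hook before speculating a load that only executes on one side of a branch. The pass context, the TargetTransformInfo reference, and the hoistAsConditionalLoad helper are all hypothetical.

// Sketch only; assumes llvm/Analysis/TargetTransformInfo.h and
// llvm/IR/Instructions.h are available and the caller owns a TTI object.
static bool trySpeculateGuardedLoad(Instruction &I,
                                    const TargetTransformInfo &TTI) {
  auto *LI = dyn_cast<LoadInst>(&I);
  if (!LI)
    return false;
  // True only when the target can suppress the fault for this scalar type
  // (on X86: 16/32/64-bit integers when the CF/CFCMOV feature is present).
  if (!TTI.hasConditionalLoadStoreForType(LI->getType()))
    return false;
  hoistAsConditionalLoad(LI); // hypothetical helper, not part of this patch
  return true;
}
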
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
}

unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }

unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
return Vector ? 1 : 0;
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
const SDValue OldRHS, SDValue &Chain,
bool IsSignaling = false) const;

virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain, MachineMemOperand *MMO,
SDValue &NewLoad, SDValue Ptr,
SDValue PassThru, SDValue Mask) const {
llvm_unreachable("Not Implemented");
}

virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain, MachineMemOperand *MMO,
SDValue Ptr, SDValue Val,
SDValue Mask) const {
llvm_unreachable("Not Implemented");
}

/// Returns a pair of (return value, chain).
/// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
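Both hooks default to llvm_unreachable, so a target must override them whenever its TTI reports conditional load/store support; the X86 overrides appear later in this patch. A minimal sketch for a hypothetical target (FooTargetLowering and FOOISD::CSTORE are invented names) might look like:

SDValue FooTargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
                                            SDValue Chain, MachineMemOperand *MMO,
                                            SDValue Ptr, SDValue Val,
                                            SDValue Mask) const {
  // Scalarize the single-element vector value and emit a target-specific
  // conditional store node that carries the memory operand.
  EVT Ty = Val.getValueType().getVectorElementType();
  SDValue ScalarVal = DAG.getBitcast(Ty, Val);
  SDValue Ops[] = {Chain, ScalarVal, Ptr, Mask};
  return DAG.getMemIntrinsicNode(FOOISD::CSTORE, DL, DAG.getVTList(MVT::Other),
                                 Ops, Ty, MMO);
}
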
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
return TTIImpl->getNumberOfRegisters(ClassID);
}

bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
return TTIImpl->hasConditionalLoadStoreForType(Ty);
}

unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
Type *Ty) const {
return TTIImpl->getRegisterClassForType(Vector, Ty);
32 changes: 26 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());

const auto &TLI = DAG.getTargetLoweringInfo();
const auto &TTI =
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
SDValue StoreNode =
DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
ISD::UNINDEXED, false /* Truncating */, IsCompressing);
!IsCompressing && TTI.hasConditionalLoadStoreForType(
I.getArgOperand(0)->getType()->getScalarType())
? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
Mask)
: DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
IsCompressing);
DAG.setRoot(StoreNode);
setValue(&I, StoreNode);
}
@@ -4958,12 +4967,23 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);

SDValue Load =
DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
const auto &TLI = DAG.getTargetLoweringInfo();
const auto &TTI =
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
// Load and Res may refer to different nodes: Load is the memory node whose
// chain is tracked below, while Res is the value recorded for the intrinsic.
SDValue Load;
SDValue Res;
if (!IsExpanding && TTI.hasConditionalLoadStoreForType(
Src0Operand->getType()->getScalarType()))
Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
else
Res = Load =
DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
if (AddToChain)
PendingLoads.push_back(Load.getValue(1));
setValue(&I, Load);
setValue(&I, Res);
}

void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
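For illustration (not from the patch): the kind of call that now reaches TLI.visitMaskedLoad is a single-element masked load whose scalar type passes hasConditionalLoadStoreForType, for example i64 on an X86 target with the CF feature. A sketch using IRBuilder, assuming Builder, Cond, Ptr, and PassThru already exist in the surrounding code:

// Build @llvm.masked.load.v1i64 guarded by a scalar i1 condition.
auto *VecTy = FixedVectorType::get(Builder.getInt64Ty(), /*NumElts=*/1);
auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), /*NumElts=*/1);
Value *Mask = Builder.CreateBitCast(Cond, MaskTy); // i1 -> <1 x i1>
Value *Load = Builder.CreateMaskedLoad(VecTy, Ptr, Align(8), Mask, PassThru);
// With this patch, SelectionDAGBuilder routes the call to
// TLI.visitMaskedLoad(), which emits X86ISD::CLOAD instead of scalarizing.
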
78 changes: 78 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,54 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
}

static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
SDValue Mask) {
EVT Ty = MVT::i8;
auto V = DAG.getBitcast(MVT::i1, Mask);
auto VE = DAG.getZExtOrTrunc(V, DL, Ty);
auto Zero = DAG.getConstant(0, DL, Ty);
SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
auto CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
return SDValue(CmpZero.getNode(), 1);
}

SDValue X86TargetLowering::visitMaskedLoad(
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
// @llvm.masked.load.v1*(ptr, alignment, mask, passthru)
// ->
// _, flags = SUB 0, mask
// res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
// bit_cast_to_vector<res>
EVT VTy = PassThru.getValueType();
EVT Ty = VTy.getVectorElementType();
SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
auto ScalarPassThru = PassThru.isUndef() ? DAG.getConstant(0, DL, Ty)
: DAG.getBitcast(Ty, PassThru);
auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
return DAG.getBitcast(VTy, NewLoad);
}

SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
SDValue Chain,
MachineMemOperand *MMO, SDValue Ptr,
SDValue Val, SDValue Mask) const {
// llvm.masked.store.v1*(Src0, Ptr, alignment, Mask)
// ->
// _, flags = SUB 0, mask
// chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
EVT Ty = Val.getValueType().getVectorElementType();
SDVTList Tys = DAG.getVTList(MVT::Other);
auto ScalarVal = DAG.getBitcast(Ty, Val);
auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
}

/// Provide custom lowering hooks for some operations.
SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -34024,6 +34072,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(STRICT_FP80_ADD)
NODE_NAME_CASE(CCMP)
NODE_NAME_CASE(CTEST)
NODE_NAME_CASE(CLOAD)
NODE_NAME_CASE(CSTORE)
}
return nullptr;
#undef NODE_NAME_CASE
@@ -55633,6 +55683,32 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}

static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
// res, flags2 = sub 0, (setcc cc, flags)
// cload/cstore ..., cond_ne, flags2
// ->
// cload/cstore cc, flags
if (N->getConstantOperandVal(3) != X86::COND_NE)
return SDValue();

SDValue Sub = N->getOperand(4);
if (Sub.getOpcode() != X86ISD::SUB)
return SDValue();

SDValue SetCC = Sub.getOperand(1);

if (!X86::isZeroNode(Sub.getOperand(0)) || SetCC.getOpcode() != X86ISD::SETCC)
return SDValue();

SmallVector<SDValue, 5> Ops(N->op_values());
Ops[3] = SetCC.getOperand(0);
Ops[4] = SetCC.getOperand(1);

return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
cast<MemSDNode>(N)->getMemoryVT(),
cast<MemSDNode>(N)->getMemOperand());
}

static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -57340,6 +57416,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
case X86ISD::ADD:
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
case X86ISD::CLOAD:
case X86ISD::CSTORE: return combineX86CloadCstore(N, DAG);
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
// is needed so that this can be expanded with control flow.
VASTART_SAVE_XMM_REGS,

// Conditional load/store instructions
CLOAD,
CSTORE,

// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
// opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
unsigned OpNo) const override;

SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
MachineMemOperand *MMO, SDValue &NewLoad,
SDValue Ptr, SDValue PassThru,
SDValue Mask) const override;
SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
SDValue Mask) const override;

/// Lower interleaved load(s) into target specific
/// instructions/intrinsics.
bool lowerInterleavedLoad(LoadInst *LI,
21 changes: 21 additions & 0 deletions llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,27 @@ let Predicates = [HasCMOV, HasCF] in {
(CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;

def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV16rm addr:$src1, timm:$cond)>;
def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV32rm addr:$src1, timm:$cond)>;
def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV64rm addr:$src1, timm:$cond)>;

def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
(CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
(CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
(CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;

def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
(CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
(CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
(CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
}

// SetCC instructions.
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
def SDTX86Ccmp : SDTypeProfile<1, 5,
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;

// RES = op PTR, PASSTHRU, COND, EFLAGS
def SDTX86Cload : SDTypeProfile<1, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
// op VAL, PTR, COND, EFLAGS
def SDTX86Cstore : SDTypeProfile<0, 4,
[SDTCisInt<0>, SDTCisPtrTy<1>,
SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;

def SDTX86Cmov : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt : SDNode<"X86ISD::BT", SDTX86CmpTest>;
def X86ccmp : SDNode<"X86ISD::CCMP", SDTX86Ccmp>;
def X86ctest : SDNode<"X86ISD::CTEST", SDTX86Ccmp>;

def X86cload : SDNode<"X86ISD::CLOAD", SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
def X86cstore : SDNode<"X86ISD::CSTORE", SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;

def X86cmov : SDNode<"X86ISD::CMOV", SDTX86Cmov>;
def X86brcond : SDNode<"X86ISD::BRCOND", SDTX86BrCond,
[SDNPHasChain]>;
44 changes: 35 additions & 9 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
return 8;
}

bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
if (!ST->hasCF())
return false;
if (!Ty)
return true;
// Conditional faulting is supported by CFCMOV, which only accepts
// 16/32/64-bit operands.
// TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
// profitable.
if (!Ty->isIntegerTy())
return false;
switch (cast<IntegerType>(Ty)->getBitWidth()) {
default:
return false;
case 16:
case 32:
case 64:
return true;
}
}

TypeSize
X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5062,17 +5083,22 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
auto VT = TLI->getValueType(DL, SrcVTy);
InstructionCost Cost = 0;
if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
MVT Ty = LT.second;
if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
// APX masked load/store for scalar is cheap.
return Cost + LT.first;

if (VT.isSimple() && Ty != VT.getSimpleVT() &&
LT.second.getVectorNumElements() == NumElem)
// Promotion requires extend/truncate for data and a shuffle for mask.
Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
CostKind, 0, nullptr) +
getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
CostKind, 0, nullptr);

else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
else if (LT.first * Ty.getVectorNumElements() > NumElem) {
auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
LT.second.getVectorNumElements());
Ty.getVectorNumElements());
// Expanding requires filling the mask with zeroes
Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
CostKind, 0, MaskTy);
@@ -5891,14 +5917,14 @@ bool X86TTIImpl::canMacroFuseCmp() {
}

bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
if (!ST->hasAVX())
return false;
Type *ScalarTy = DataTy->getScalarType();

// The backend can't handle a single element vector.
if (isa<VectorType>(DataTy) &&
cast<FixedVectorType>(DataTy)->getNumElements() == 1)
// The backend can't handle a single element vector w/o CFCMOV.
if (isa<VectorType>(DataTy) && cast<FixedVectorType>(DataTy)->getNumElements() == 1)
return ST->hasCF() && hasConditionalLoadStoreForType(ScalarTy);

if (!ST->hasAVX())
return false;
Type *ScalarTy = DataTy->getScalarType();

if (ScalarTy->isPointerTy())
return true;
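As a result of this change, a one-element masked load can be reported legal even without AVX. A hedged sketch of a legality query, assuming an X86 TargetTransformInfo for a subtarget with the CF feature and an LLVMContext Ctx:

// <1 x i64> masked load: previously always illegal, now legal when the
// subtarget supports conditional faulting (CFCMOV).
auto *V1I64 = FixedVectorType::get(Type::getInt64Ty(Ctx), /*NumElts=*/1);
bool Legal = TTI.isLegalMaskedLoad(V1I64, Align(8)); // true only with CF
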
1 change: 1 addition & 0 deletions llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -132,6 +132,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
/// @{

unsigned getNumberOfRegisters(unsigned ClassID) const;
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
unsigned getMaxInterleaveFactor(ElementCount VF);