-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[X86][CodeGen] Support hoisting load/store with conditional faulting #96720
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
1. Add TTI interface for conditional load/store. 2. Mark 1 x i16/i32/i64 masked load/store legal so that it's not legalized in pass scalarize-masked-mem-intrin. 3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific CLOAD/CSTORE node to avoid error in `DAGTypeLegalizer::ScalarizeVectorResult`. 4. Combine DAG to simplify the nodes for CLOAD/CSTORE. 5. Lower CLOAD/CSTORE to CFCMOV by pattern match. This is CodeGen part of llvm#95515
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-llvm-analysis Author: Shengchen Kan (KanRobert) Changes
This is CodeGen part of #95515 Patch is 24.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96720.diff 12 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..f5c0127e1d422 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
/// \return the number of registers in the target-provided register class.
unsigned getNumberOfRegisters(unsigned ClassID) const;
+ /// \return true if the target supports load/store that enables fault
+ /// suppression of memory operands when the source condition is false.
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+
/// \return the target-provided register class ID for the provided type,
/// accounting for type promotion and other type-legalization techniques that
/// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
const Function &Fn) const = 0;
virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
+ virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
virtual unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const = 0;
virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
unsigned getNumberOfRegisters(unsigned ClassID) const override {
return Impl.getNumberOfRegisters(ClassID);
}
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
+ return Impl.hasConditionalLoadStoreForType(Ty);
+ }
unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const override {
return Impl.getRegisterClassForType(Vector, Ty);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..49b4bd00baed4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
}
unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
+ bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
return Vector ? 1 : 0;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 06f7ee2a589c8..9a0df8b29d752 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
const SDValue OldRHS, SDValue &Chain,
bool IsSignaling = false) const;
+ virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr,
+ SDValue PassThru, SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
+ virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue Ptr, SDValue Val,
+ SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
/// Returns a pair of (return value, chain).
/// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..0db8a4201fead 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
return TTIImpl->getNumberOfRegisters(ClassID);
}
+bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
+ return TTIImpl->hasConditionalLoadStoreForType(Ty);
+}
+
unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
Type *Ty) const {
return TTIImpl->getRegisterClassForType(Vector, Ty);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 296b06187ec0f..1f9e73ef949e8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
+
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
SDValue StoreNode =
- DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
- ISD::UNINDEXED, false /* Truncating */, IsCompressing);
+ (!IsCompressing && TTI.hasConditionalLoadStoreForType(
+ I.getArgOperand(0)->getType()->getScalarType()))
+ ? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
+ Mask)
+ : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
+ VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
+ IsCompressing);
DAG.setRoot(StoreNode);
setValue(&I, StoreNode);
}
@@ -4958,12 +4967,22 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
- SDValue Load =
- DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
- ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+ // The Load/Res may point to different values.
+ SDValue Load;
+ SDValue Res;
+ if (!IsExpanding && TTI.hasConditionalLoadStoreForType(
+ Src0Operand->getType()->getScalarType()))
+ Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
+ else
+ Res = Load =
+ DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
+ ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
if (AddToChain)
PendingLoads.push_back(Load.getValue(1));
- setValue(&I, Load);
+ setValue(&I, Res);
}
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f27c935812f51..a45e18ae67a91 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,55 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
}
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue V) {
+ assert(V.getValueType() == MVT::i1 && "assume i1 value");
+ EVT Ty = MVT::i8;
+ SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+ SDValue Zero = DAG.getConstant(0, DL, Ty);
+ SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
+ SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+ return SDValue(CmpZero.getNode(), 1);
+}
+
+SDValue X86TargetLowering::visitMaskedLoad(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+ // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+ // ->
+ // _, flags = SUB 0, mask
+ // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+ // bit_cast_to_vector<res>
+ EVT VTy = PassThru.getValueType();
+ EVT Ty = VTy.getVectorElementType();
+ SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+ SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
+ NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
+ return DAG.getBitcast(VTy, NewLoad);
+}
+
+SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr,
+ SDValue Val, SDValue Mask) const {
+ // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+ // ->
+ // _, flags = SUB 0, mask
+ // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
+ EVT Ty = Val.getValueType().getVectorElementType();
+ SDVTList Tys = DAG.getVTList(MVT::Other);
+ SDValue ScalarVal = DAG.getBitcast(Ty, Val);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
+ return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
+}
+
/// Provide custom lowering hooks for some operations.
SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -34024,6 +34073,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(STRICT_FP80_ADD)
NODE_NAME_CASE(CCMP)
NODE_NAME_CASE(CTEST)
+ NODE_NAME_CASE(CLOAD)
+ NODE_NAME_CASE(CSTORE)
}
return nullptr;
#undef NODE_NAME_CASE
@@ -55633,6 +55684,36 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
+static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
+ // res, flags2 = sub 0, (setcc cc, flag)
+ // cload/cstore ..., cond_ne, flag2
+ // ->
+ // cload/cstore cc, flag
+ //
+ // if res has no users, where op is cload/cstore.
+ if (N->getConstantOperandVal(3) != X86::COND_NE)
+ return SDValue();
+
+ SDNode *Sub = N->getOperand(4).getNode();
+ if (Sub->getOpcode() != X86ISD::SUB)
+ return SDValue();
+
+ SDValue Op1 = Sub->getOperand(1);
+
+ if (Sub->hasAnyUseOfValue(0) || !X86::isZeroNode(Sub->getOperand(0)) ||
+ Op1.getOpcode() != X86ISD::SETCC)
+ return SDValue();
+
+
+ SmallVector<SDValue> Ops(N->op_values());
+ Ops[3] = Op1.getOperand(0);
+ Ops[4] = Op1.getOperand(1);
+
+ return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
+ cast<MemSDNode>(N)->getMemoryVT(),
+ cast<MemSDNode>(N)->getMemOperand());
+}
+
static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -57340,6 +57421,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
case X86ISD::ADD:
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
+ case X86ISD::CLOAD:
+ case X86ISD::CSTORE: return combineX86CloadCstore(N, DAG);
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 3c5c903bc0d98..362daa98e1f8e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
// is needed so that this can be expanded with control flow.
VASTART_SAVE_XMM_REGS,
+ // Conditional load/store instructions
+ CLOAD,
+ CSTORE,
+
// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
// opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
unsigned OpNo) const override;
+ SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue &NewLoad,
+ SDValue Ptr, SDValue PassThru,
+ SDValue Mask) const override;
+ SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
+ SDValue Mask) const override;
+
/// Lower interleaved load(s) into target specific
/// instructions/intrinsics.
bool lowerInterleavedLoad(LoadInst *LI,
diff --git a/llvm/lib/Target/X86/X86InstrCMovSetCC.td b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
index e27aa4115990e..543057c58035a 100644
--- a/llvm/lib/Target/X86/X86InstrCMovSetCC.td
+++ b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,35 @@ let Predicates = [HasCMOV, HasCF] in {
(CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
+
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ // FIXME: Shouldn't patterns for 0 work for undef?
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
+
+ def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
}
// SetCC instructions.
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 162e322712a6d..972b56e0f0cfe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
def SDTX86Ccmp : SDTypeProfile<1, 5,
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
+// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+def SDTX86Cload : SDTypeProfile<1, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
+ SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
+// chain = CSTORE inchain, val, ptr, cond, flags
+def SDTX86Cstore : SDTypeProfile<0, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>,
+ SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
+
def SDTX86Cmov : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt : SDNode<"X86ISD::BT", SDTX86CmpTest>;
def X86ccmp : SDNode<"X86ISD::CCMP", SDTX86Ccmp>;
def X86ctest : SDNode<"X86ISD::CTEST", SDTX86Ccmp>;
+def X86cload : SDNode<"X86ISD::CLOAD", SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86cstore : SDNode<"X86ISD::CSTORE", SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
def X86cmov : SDNode<"X86ISD::CMOV", SDTX86Cmov>;
def X86brcond : SDNode<"X86ISD::BRCOND", SDTX86BrCond,
[SDNPHasChain]>;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index de0144331dba3..aad4b9039bbb1 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
return 8;
}
+bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
+ if (!ST->hasCF())
+ return false;
+ if (!Ty)
+ return true;
+ // Conditional faulting is supported by CFCMOV, which only accepts
+ // 16/32/64-bit operands.
+ // TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
+ // profitable.
+ if (!Ty->isIntegerTy())
+ return false;
+ switch (cast<IntegerType>(Ty)->getBitWidth()) {
+ default:
+ return false;
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ }
+}
+
TypeSize
X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5062,7 +5083,12 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
auto VT = TLI->getValueType(DL, SrcVTy);
InstructionCost Cost = 0;
- if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
+ MVT Ty = LT.second;
+ if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
+ // APX masked load/store for scalar is cheap.
+ return Cost + LT.first;
+
+ if (VT.isSimple() && Ty != VT.getSimpleVT() &&
LT.second.getVectorNumElements() == NumElem)
// Promotion requires extend/truncate for data and a shuffle for mask.
Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
@@ -5070,9 +5096,9 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
CostKind, 0, nullptr);
- else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
+ else if (LT.first * Ty.getVectorNumElements() > NumElem) {
auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
- LT.second.getVectorNumElements());
+ Ty.getVectorNumElements());
// Expanding requires fill mask with zeroes
Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
CostKind, 0, MaskTy);
@@ -5891,14 +5917,21 @@ bool X86TTIImpl::canMacroFuseCmp() {
}
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+ bool IsSingleElementVector =
+ isa<VectorType>(DataTy) &&
+ cast<FixedVectorType>(DataTy)->getNumElements() == 1;
+ Type *ScalarTy = DataTy->getScalarType();
+
+ if (ST->hasCF() && IsSingleElementVector &&
+ hasConditionalLoadStoreForType(ScalarTy))
+ return true;
+
if (!ST->hasAVX())
return false;
- // The backend can't handle a single element vector.
- if (isa<VectorType>(Dat...
[truncated]
|
@llvm/pr-subscribers-backend-x86 Author: Shengchen Kan (KanRobert) Changes
This is CodeGen part of #95515 Patch is 24.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96720.diff 12 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..f5c0127e1d422 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
/// \return the number of registers in the target-provided register class.
unsigned getNumberOfRegisters(unsigned ClassID) const;
+ /// \return true if the target supports load/store that enables fault
+ /// suppression of memory operands when the source condition is false.
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+
/// \return the target-provided register class ID for the provided type,
/// accounting for type promotion and other type-legalization techniques that
/// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
const Function &Fn) const = 0;
virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
+ virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
virtual unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const = 0;
virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
unsigned getNumberOfRegisters(unsigned ClassID) const override {
return Impl.getNumberOfRegisters(ClassID);
}
+ bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
+ return Impl.hasConditionalLoadStoreForType(Ty);
+ }
unsigned getRegisterClassForType(bool Vector,
Type *Ty = nullptr) const override {
return Impl.getRegisterClassForType(Vector, Ty);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..49b4bd00baed4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
}
unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
+ bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
return Vector ? 1 : 0;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 06f7ee2a589c8..9a0df8b29d752 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
const SDValue OldRHS, SDValue &Chain,
bool IsSignaling = false) const;
+ virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr,
+ SDValue PassThru, SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
+ virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain, MachineMemOperand *MMO,
+ SDValue Ptr, SDValue Val,
+ SDValue Mask) const {
+ llvm_unreachable("Not Implemented");
+ }
+
/// Returns a pair of (return value, chain).
/// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..0db8a4201fead 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
return TTIImpl->getNumberOfRegisters(ClassID);
}
+bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
+ return TTIImpl->hasConditionalLoadStoreForType(Ty);
+}
+
unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
Type *Ty) const {
return TTIImpl->getRegisterClassForType(Vector, Ty);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 296b06187ec0f..1f9e73ef949e8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
+
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
SDValue StoreNode =
- DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
- ISD::UNINDEXED, false /* Truncating */, IsCompressing);
+ (!IsCompressing && TTI.hasConditionalLoadStoreForType(
+ I.getArgOperand(0)->getType()->getScalarType()))
+ ? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
+ Mask)
+ : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
+ VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
+ IsCompressing);
DAG.setRoot(StoreNode);
setValue(&I, StoreNode);
}
@@ -4958,12 +4967,22 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
MachinePointerInfo(PtrOperand), MMOFlags,
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
- SDValue Load =
- DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
- ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
+ const auto &TLI = DAG.getTargetLoweringInfo();
+ const auto &TTI =
+ TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+ // The Load/Res may point to different values.
+ SDValue Load;
+ SDValue Res;
+ if (!IsExpanding && TTI.hasConditionalLoadStoreForType(
+ Src0Operand->getType()->getScalarType()))
+ Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
+ else
+ Res = Load =
+ DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
+ ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
if (AddToChain)
PendingLoads.push_back(Load.getValue(1));
- setValue(&I, Load);
+ setValue(&I, Res);
}
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f27c935812f51..a45e18ae67a91 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,55 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
}
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue V) {
+ assert(V.getValueType() == MVT::i1 && "assume i1 value");
+ EVT Ty = MVT::i8;
+ SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+ SDValue Zero = DAG.getConstant(0, DL, Ty);
+ SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
+ SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+ return SDValue(CmpZero.getNode(), 1);
+}
+
+SDValue X86TargetLowering::visitMaskedLoad(
+ SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+ SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+ // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+ // ->
+ // _, flags = SUB 0, mask
+ // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+ // bit_cast_to_vector<res>
+ EVT VTy = PassThru.getValueType();
+ EVT Ty = VTy.getVectorElementType();
+ SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+ SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
+ NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
+ return DAG.getBitcast(VTy, NewLoad);
+}
+
+SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+ SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr,
+ SDValue Val, SDValue Mask) const {
+ // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+ // ->
+ // _, flags = SUB 0, mask
+ // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
+ EVT Ty = Val.getValueType().getVectorElementType();
+ SDVTList Tys = DAG.getVTList(MVT::Other);
+ SDValue ScalarVal = DAG.getBitcast(Ty, Val);
+ SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+ SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+ SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+ SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
+ return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
+}
+
/// Provide custom lowering hooks for some operations.
SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -34024,6 +34073,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(STRICT_FP80_ADD)
NODE_NAME_CASE(CCMP)
NODE_NAME_CASE(CTEST)
+ NODE_NAME_CASE(CLOAD)
+ NODE_NAME_CASE(CSTORE)
}
return nullptr;
#undef NODE_NAME_CASE
@@ -55633,6 +55684,36 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
+static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
+ // res, flags2 = sub 0, (setcc cc, flag)
+ // cload/cstore ..., cond_ne, flag2
+ // ->
+ // cload/cstore cc, flag
+ //
+ // if res has no users, where op is cload/cstore.
+ if (N->getConstantOperandVal(3) != X86::COND_NE)
+ return SDValue();
+
+ SDNode *Sub = N->getOperand(4).getNode();
+ if (Sub->getOpcode() != X86ISD::SUB)
+ return SDValue();
+
+ SDValue Op1 = Sub->getOperand(1);
+
+ if (Sub->hasAnyUseOfValue(0) || !X86::isZeroNode(Sub->getOperand(0)) ||
+ Op1.getOpcode() != X86ISD::SETCC)
+ return SDValue();
+
+
+ SmallVector<SDValue> Ops(N->op_values());
+ Ops[3] = Op1.getOperand(0);
+ Ops[4] = Op1.getOperand(1);
+
+ return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
+ cast<MemSDNode>(N)->getMemoryVT(),
+ cast<MemSDNode>(N)->getMemOperand());
+}
+
static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -57340,6 +57421,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
case X86ISD::ADD:
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
+ case X86ISD::CLOAD:
+ case X86ISD::CSTORE: return combineX86CloadCstore(N, DAG);
case X86ISD::SBB: return combineSBB(N, DAG);
case X86ISD::ADC: return combineADC(N, DAG, DCI);
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 3c5c903bc0d98..362daa98e1f8e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
// is needed so that this can be expanded with control flow.
VASTART_SAVE_XMM_REGS,
+ // Conditional load/store instructions
+ CLOAD,
+ CSTORE,
+
// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
// opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
unsigned OpNo) const override;
+ SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue &NewLoad,
+ SDValue Ptr, SDValue PassThru,
+ SDValue Mask) const override;
+ SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+ MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
+ SDValue Mask) const override;
+
/// Lower interleaved load(s) into target specific
/// instructions/intrinsics.
bool lowerInterleavedLoad(LoadInst *LI,
diff --git a/llvm/lib/Target/X86/X86InstrCMovSetCC.td b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
index e27aa4115990e..543057c58035a 100644
--- a/llvm/lib/Target/X86/X86InstrCMovSetCC.td
+++ b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,35 @@ let Predicates = [HasCMOV, HasCF] in {
(CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
(CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
+
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ // FIXME: Shouldn't patterns for 0 work for undef?
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV16rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV32rm addr:$src1, timm:$cond)>;
+ def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+ (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+ def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
+ def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
+
+ def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
+ def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
+ (CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
}
// SetCC instructions.
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 162e322712a6d..972b56e0f0cfe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
def SDTX86Ccmp : SDTypeProfile<1, 5,
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
+// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+def SDTX86Cload : SDTypeProfile<1, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
+ SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
+// chain = CSTORE inchain, val, ptr, cond, flags
+def SDTX86Cstore : SDTypeProfile<0, 4,
+ [SDTCisInt<0>, SDTCisPtrTy<1>,
+ SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
+
def SDTX86Cmov : SDTypeProfile<1, 4,
[SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt : SDNode<"X86ISD::BT", SDTX86CmpTest>;
def X86ccmp : SDNode<"X86ISD::CCMP", SDTX86Ccmp>;
def X86ctest : SDNode<"X86ISD::CTEST", SDTX86Ccmp>;
+def X86cload : SDNode<"X86ISD::CLOAD", SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86cstore : SDNode<"X86ISD::CSTORE", SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
def X86cmov : SDNode<"X86ISD::CMOV", SDTX86Cmov>;
def X86brcond : SDNode<"X86ISD::BRCOND", SDTX86BrCond,
[SDNPHasChain]>;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index de0144331dba3..aad4b9039bbb1 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
return 8;
}
+bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
+ if (!ST->hasCF())
+ return false;
+ if (!Ty)
+ return true;
+ // Conditional faulting is supported by CFCMOV, which only accepts
+ // 16/32/64-bit operands.
+ // TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
+ // profitable.
+ if (!Ty->isIntegerTy())
+ return false;
+ switch (cast<IntegerType>(Ty)->getBitWidth()) {
+ default:
+ return false;
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ }
+}
+
TypeSize
X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5062,7 +5083,12 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
auto VT = TLI->getValueType(DL, SrcVTy);
InstructionCost Cost = 0;
- if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
+ MVT Ty = LT.second;
+ if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
+ // APX masked load/store for scalar is cheap.
+ return Cost + LT.first;
+
+ if (VT.isSimple() && Ty != VT.getSimpleVT() &&
LT.second.getVectorNumElements() == NumElem)
// Promotion requires extend/truncate for data and a shuffle for mask.
Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
@@ -5070,9 +5096,9 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
CostKind, 0, nullptr);
- else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
+ else if (LT.first * Ty.getVectorNumElements() > NumElem) {
auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
- LT.second.getVectorNumElements());
+ Ty.getVectorNumElements());
// Expanding requires fill mask with zeroes
Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
CostKind, 0, MaskTy);
@@ -5891,14 +5917,21 @@ bool X86TTIImpl::canMacroFuseCmp() {
}
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+ bool IsSingleElementVector =
+ isa<VectorType>(DataTy) &&
+ cast<FixedVectorType>(DataTy)->getNumElements() == 1;
+ Type *ScalarTy = DataTy->getScalarType();
+
+ if (ST->hasCF() && IsSingleElementVector &&
+ hasConditionalLoadStoreForType(ScalarTy))
+ return true;
+
if (!ST->hasAVX())
return false;
- // The backend can't handle a single element vector.
- if (isa<VectorType>(Dat...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Do we need to wait for the CFG patch?
Thanks. I don't think we need to. We always use masked.load/store for the CFG transform. The possible difference is that in which stage we emit them and in which cases it's profitable. Also, middle-end patch can only be buildable after this lands. |
This comment was marked as off-topic.
This comment was marked as off-topic.
…lvm#96720) 1. Add TTI interface for conditional load/store. 2. Mark 1 x i16/i32/i64 masked load/store legal so that it's not legalized in pass scalarize-masked-mem-intrin. 3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific CLOAD/CSTORE node to avoid error in `DAGTypeLegalizer::ScalarizeVectorResult`. 4. Combine DAG to simplify the nodes for CLOAD/CSTORE. 5. Lower CLOAD/CSTORE to CFCMOV by pattern match. This is CodeGen part of llvm#95515
…lvm#96720) 1. Add TTI interface for conditional load/store. 2. Mark 1 x i16/i32/i64 masked load/store legal so that it's not legalized in pass scalarize-masked-mem-intrin. 3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific CLOAD/CSTORE node to avoid error in `DAGTypeLegalizer::ScalarizeVectorResult`. 4. Combine DAG to simplify the nodes for CLOAD/CSTORE. 5. Lower CLOAD/CSTORE to CFCMOV by pattern match. This is CodeGen part of llvm#95515
legalized in pass scalarize-masked-mem-intrin.
CLOAD/CSTORE node to avoid error in
DAGTypeLegalizer::ScalarizeVectorResult
.This is CodeGen part of #95515