[X86][CodeGen] Support hoisting load/store with conditional faulting #96720


Merged
merged 6 commits into llvm:main
Jun 27, 2024

Conversation

KanRobert
Contributor

  1. Add a TTI interface for conditional load/store.
  2. Mark 1 x i16/i32/i64 masked load/store legal so that they are not
     legalized in the scalarize-masked-mem-intrin pass.
  3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific
     CLOAD/CSTORE node, avoiding an error in
     DAGTypeLegalizer::ScalarizeVectorResult.
  4. Combine the DAG to simplify the CLOAD/CSTORE nodes.
  5. Lower CLOAD/CSTORE to CFCMOV by pattern matching.

This is the CodeGen part of #95515.
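As a rough illustration (the real FileCheck lines live in the added llvm/test/CodeGen/X86/apx/cf.ll; the function and value names below are invented), this is the kind of IR these changes target. On a CF-capable (APX) target, a 1 x i32 masked load with a zero passthru is expected to select a conditional-faulting CFCMOV load instead of being scalarized into a branch:

```llvm
; Hypothetical sketch, not copied from the patch's test file.
declare <1 x i32> @llvm.masked.load.v1i32.p0(ptr, i32 immarg, <1 x i1>, <1 x i32>)

define i32 @cond_load(i1 %cond, ptr %p) {
  ; Build the 1 x i1 mask from the scalar condition.
  %m = insertelement <1 x i1> poison, i1 %cond, i64 0
  ; Single-element masked load: yields the passthru (0) when %cond is false.
  %vv = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr %p, i32 4, <1 x i1> %m, <1 x i32> zeroinitializer)
  %v = extractelement <1 x i32> %vv, i64 0
  ; Expected lowering (roughly): set EFLAGS from %cond, then something like
  ;   cfcmovnel (%rsi), %eax
  ; so no memory access (and no fault) happens when the condition is false.
  ret i32 %v
}
```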

@llvmbot llvmbot added the backend:X86, llvm:SelectionDAG (SelectionDAGISel as well) and llvm:analysis (value tracking, cost tables and constant folding) labels Jun 26, 2024
@llvmbot
Member

llvmbot commented Jun 26, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-llvm-analysis

Author: Shengchen Kan (KanRobert)

Changes
  1. Add a TTI interface for conditional load/store.
  2. Mark 1 x i16/i32/i64 masked load/store legal so that they are not
     legalized in the scalarize-masked-mem-intrin pass.
  3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific
     CLOAD/CSTORE node, avoiding an error in
     DAGTypeLegalizer::ScalarizeVectorResult.
  4. Combine the DAG to simplify the CLOAD/CSTORE nodes.
  5. Lower CLOAD/CSTORE to CFCMOV by pattern matching.

This is the CodeGen part of #95515.
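A store-side sketch along the same lines (hypothetical, mirroring rather than quoting the added test): a 1 x i64 masked store is expected to select the memory-destination form of CFCMOV, so the write and any potential fault are suppressed when the mask is false.

```llvm
; Hypothetical sketch; names invented for illustration.
declare void @llvm.masked.store.v1i64.p0(<1 x i64>, ptr, i32 immarg, <1 x i1>)

define void @cond_store(i1 %cond, i64 %val, ptr %p) {
  %m = insertelement <1 x i1> poison, i1 %cond, i64 0
  %vv = insertelement <1 x i64> poison, i64 %val, i64 0
  ; Conditional store: nothing is written when %cond is false.
  call void @llvm.masked.store.v1i64.p0(<1 x i64> %vv, ptr %p, i32 8, <1 x i1> %m)
  ; Expected lowering (roughly): set EFLAGS from %cond, then something like
  ;   cfcmovneq %rsi, (%rdx)
  ret void
}
```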


Patch is 24.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96720.diff

12 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+14)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+25-6)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+83)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+12)
  • (modified) llvm/lib/Target/X86/X86InstrCMovSetCC.td (+29)
  • (modified) llvm/lib/Target/X86/X86InstrFragments.td (+12)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.cpp (+40-7)
  • (modified) llvm/lib/Target/X86/X86TargetTransformInfo.h (+1)
  • (added) llvm/test/CodeGen/X86/apx/cf.ll (+85)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index f55f21c94a85a..f5c0127e1d422 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
   /// \return the number of registers in the target-provided register class.
   unsigned getNumberOfRegisters(unsigned ClassID) const;
 
+  /// \return true if the target supports load/store that enables fault
+  /// suppression of memory operands when the source condition is false.
+  bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
+
   /// \return the target-provided register class ID for the provided type,
   /// accounting for type promotion and other type-legalization techniques that
   /// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
   virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
                                              const Function &Fn) const = 0;
   virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
+  virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
   virtual unsigned getRegisterClassForType(bool Vector,
                                            Type *Ty = nullptr) const = 0;
   virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getNumberOfRegisters(unsigned ClassID) const override {
     return Impl.getNumberOfRegisters(ClassID);
   }
+  bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
+    return Impl.hasConditionalLoadStoreForType(Ty);
+  }
   unsigned getRegisterClassForType(bool Vector,
                                    Type *Ty = nullptr) const override {
     return Impl.getRegisterClassForType(Vector, Ty);
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 7828bdc1f1f43..49b4bd00baed4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
   }
 
   unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
+  bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
 
   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
     return Vector ? 1 : 0;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 06f7ee2a589c8..9a0df8b29d752 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
                            const SDValue OldRHS, SDValue &Chain,
                            bool IsSignaling = false) const;
 
+  virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
+                                  SDValue Chain, MachineMemOperand *MMO,
+                                  SDValue &NewLoad, SDValue Ptr,
+                                  SDValue PassThru, SDValue Mask) const {
+    llvm_unreachable("Not Implemented");
+  }
+
+  virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+                                   SDValue Chain, MachineMemOperand *MMO,
+                                   SDValue Ptr, SDValue Val,
+                                   SDValue Mask) const {
+    llvm_unreachable("Not Implemented");
+  }
+
   /// Returns a pair of (return value, chain).
   /// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
   std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 7e721cbc87f3f..0db8a4201fead 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
   return TTIImpl->getNumberOfRegisters(ClassID);
 }
 
+bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
+  return TTIImpl->hasConditionalLoadStoreForType(Ty);
+}
+
 unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
                                                       Type *Ty) const {
   return TTIImpl->getRegisterClassForType(Vector, Ty);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 296b06187ec0f..1f9e73ef949e8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
   MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
       MachinePointerInfo(PtrOperand), MMOFlags,
       LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
+
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  const auto &TTI =
+      TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
   SDValue StoreNode =
-      DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
-                         ISD::UNINDEXED, false /* Truncating */, IsCompressing);
+      (!IsCompressing && TTI.hasConditionalLoadStoreForType(
+                             I.getArgOperand(0)->getType()->getScalarType()))
+          ? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
+                                 Mask)
+          : DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
+                               VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
+                               IsCompressing);
   DAG.setRoot(StoreNode);
   setValue(&I, StoreNode);
 }
@@ -4958,12 +4967,22 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
       MachinePointerInfo(PtrOperand), MMOFlags,
       LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
 
-  SDValue Load =
-      DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
-                        ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  const auto &TTI =
+      TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
+  // The Load/Res may point to different values.
+  SDValue Load;
+  SDValue Res;
+  if (!IsExpanding && TTI.hasConditionalLoadStoreForType(
+                          Src0Operand->getType()->getScalarType()))
+    Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
+  else
+    Res = Load =
+        DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
+                          ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
   if (AddToChain)
     PendingLoads.push_back(Load.getValue(1));
-  setValue(&I, Load);
+  setValue(&I, Res);
 }
 
 void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f27c935812f51..a45e18ae67a91 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32308,6 +32308,55 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
   return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
 }
 
+static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
+                                      SDValue V) {
+  assert(V.getValueType() == MVT::i1 && "assume i1 value");
+  EVT Ty = MVT::i8;
+  SDValue VE = DAG.getZExtOrTrunc(V, DL, Ty);
+  SDValue Zero = DAG.getConstant(0, DL, Ty);
+  SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
+  SDValue CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
+  return SDValue(CmpZero.getNode(), 1);
+}
+
+SDValue X86TargetLowering::visitMaskedLoad(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
+    SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
+  // @llvm.masked.load.*(ptr, alignment, mask, passthru)
+  // ->
+  // _, flags = SUB 0, mask
+  // res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
+  // bit_cast_to_vector<res>
+  EVT VTy = PassThru.getValueType();
+  EVT Ty = VTy.getVectorElementType();
+  SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
+  SDValue ScalarPassThru = DAG.getBitcast(Ty, PassThru);
+  SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+  SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+  SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+  SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
+  NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
+  return DAG.getBitcast(VTy, NewLoad);
+}
+
+SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
+                                            SDValue Chain,
+                                            MachineMemOperand *MMO, SDValue Ptr,
+                                            SDValue Val, SDValue Mask) const {
+  // llvm.masked.store.*(Src0, Ptr, alignment, Mask)
+  // ->
+  // _, flags = SUB 0, mask
+  // chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
+  EVT Ty = Val.getValueType().getVectorElementType();
+  SDVTList Tys = DAG.getVTList(MVT::Other);
+  SDValue ScalarVal = DAG.getBitcast(Ty, Val);
+  SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
+  SDValue Flags = getFlagsOfCmpZeroFori1(DAG, DL, ScalarMask);
+  SDValue COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
+  SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
+  return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
+}
+
 /// Provide custom lowering hooks for some operations.
 SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -34024,6 +34073,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(STRICT_FP80_ADD)
   NODE_NAME_CASE(CCMP)
   NODE_NAME_CASE(CTEST)
+  NODE_NAME_CASE(CLOAD)
+  NODE_NAME_CASE(CSTORE)
   }
   return nullptr;
 #undef NODE_NAME_CASE
@@ -55633,6 +55684,36 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
+  // res, flags2 = sub 0, (setcc cc, flag)
+  // cload/cstore ..., cond_ne, flag2
+  // ->
+  // cload/cstore cc, flag
+  //
+  // if res has no users, where op is cload/cstore.
+  if (N->getConstantOperandVal(3) != X86::COND_NE)
+    return SDValue();
+
+  SDNode *Sub = N->getOperand(4).getNode();
+  if (Sub->getOpcode() != X86ISD::SUB)
+    return SDValue();
+
+  SDValue Op1 = Sub->getOperand(1);
+
+  if (Sub->hasAnyUseOfValue(0) || !X86::isZeroNode(Sub->getOperand(0)) ||
+      Op1.getOpcode() != X86ISD::SETCC)
+    return SDValue();
+
+
+  SmallVector<SDValue> Ops(N->op_values());
+  Ops[3] = Op1.getOperand(0);
+  Ops[4] = Op1.getOperand(1);
+
+  return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
+                                 cast<MemSDNode>(N)->getMemoryVT(),
+                                 cast<MemSDNode>(N)->getMemOperand());
+}
+
 static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
@@ -57340,6 +57421,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SUB:            return combineSub(N, DAG, DCI, Subtarget);
   case X86ISD::ADD:
   case X86ISD::SUB:         return combineX86AddSub(N, DAG, DCI, Subtarget);
+  case X86ISD::CLOAD:
+  case X86ISD::CSTORE:      return combineX86CloadCstore(N, DAG);
   case X86ISD::SBB:         return combineSBB(N, DAG);
   case X86ISD::ADC:         return combineADC(N, DAG, DCI);
   case ISD::MUL:            return combineMul(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 3c5c903bc0d98..362daa98e1f8e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -903,6 +903,10 @@ namespace llvm {
     // is needed so that this can be expanded with control flow.
     VASTART_SAVE_XMM_REGS,
 
+    // Conditional load/store instructions
+    CLOAD,
+    CSTORE,
+
     // WARNING: Do not add anything in the end unless you want the node to
     // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
     // opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
     bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
                                  unsigned OpNo) const override;
 
+    SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+                            MachineMemOperand *MMO, SDValue &NewLoad,
+                            SDValue Ptr, SDValue PassThru,
+                            SDValue Mask) const override;
+    SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+                             MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
+                             SDValue Mask) const override;
+
     /// Lower interleaved load(s) into target specific
     /// instructions/intrinsics.
     bool lowerInterleavedLoad(LoadInst *LI,
diff --git a/llvm/lib/Target/X86/X86InstrCMovSetCC.td b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
index e27aa4115990e..543057c58035a 100644
--- a/llvm/lib/Target/X86/X86InstrCMovSetCC.td
+++ b/llvm/lib/Target/X86/X86InstrCMovSetCC.td
@@ -113,6 +113,35 @@ let Predicates = [HasCMOV, HasCF] in {
             (CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
   def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
             (CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
+
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV16rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV32rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
+            (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+  // FIXME: Shouldn't patterns for 0 work for undef?
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV16rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV32rm addr:$src1, timm:$cond)>;
+  def : Pat<(X86cload addr:$src1, undef, timm:$cond, EFLAGS),
+            (CFCMOV64rm addr:$src1, timm:$cond)>;
+
+  def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
+            (CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
+  def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
+            (CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
+  def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
+            (CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
+
+  def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
+  def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
+  def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
+            (CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
 }
 
 // SetCC instructions.
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 162e322712a6d..972b56e0f0cfe 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -15,6 +15,15 @@ def SDTX86FCmp    : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
 def SDTX86Ccmp    : SDTypeProfile<1, 5,
                                   [SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
 
+// res, chain = CLOAD inchain, ptr, passthru, cond, flags
+def SDTX86Cload    : SDTypeProfile<1, 4,
+                                  [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
+                                   SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
+// chain = CSTORE inchain, val, ptr, cond, flags
+def SDTX86Cstore    : SDTypeProfile<0, 4,
+                                  [SDTCisInt<0>, SDTCisPtrTy<1>,
+                                   SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
+
 def SDTX86Cmov    : SDTypeProfile<1, 4,
                                   [SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
                                    SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt      : SDNode<"X86ISD::BT",       SDTX86CmpTest>;
 def X86ccmp    : SDNode<"X86ISD::CCMP",     SDTX86Ccmp>;
 def X86ctest   : SDNode<"X86ISD::CTEST",    SDTX86Ccmp>;
 
+def X86cload    : SDNode<"X86ISD::CLOAD",   SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86cstore   : SDNode<"X86ISD::CSTORE",  SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 def X86cmov    : SDNode<"X86ISD::CMOV",     SDTX86Cmov>;
 def X86brcond  : SDNode<"X86ISD::BRCOND",   SDTX86BrCond,
                         [SDNPHasChain]>;
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index de0144331dba3..aad4b9039bbb1 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
   return 8;
 }
 
+bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
+  if (!ST->hasCF())
+    return false;
+  if (!Ty)
+    return true;
+  // Conditional faulting is supported by CFCMOV, which only accepts
+  // 16/32/64-bit operands.
+  // TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
+  // profitable.
+  if (!Ty->isIntegerTy())
+    return false;
+  switch (cast<IntegerType>(Ty)->getBitWidth()) {
+  default:
+    return false;
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  }
+}
+
 TypeSize
 X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5062,7 +5083,12 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
   auto VT = TLI->getValueType(DL, SrcVTy);
   InstructionCost Cost = 0;
-  if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
+  MVT Ty = LT.second;
+  if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
+    // APX masked load/store for scalar is cheap.
+    return Cost + LT.first;
+
+  if (VT.isSimple() && Ty != VT.getSimpleVT() &&
       LT.second.getVectorNumElements() == NumElem)
     // Promotion requires extend/truncate for data and a shuffle for mask.
     Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
@@ -5070,9 +5096,9 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
             getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
                            CostKind, 0, nullptr);
 
-  else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
+  else if (LT.first * Ty.getVectorNumElements() > NumElem) {
     auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
-                                           LT.second.getVectorNumElements());
+                                           Ty.getVectorNumElements());
     // Expanding requires fill mask with zeroes
     Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
                            CostKind, 0, MaskTy);
@@ -5891,14 +5917,21 @@ bool X86TTIImpl::canMacroFuseCmp() {
 }
 
 bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
+  bool IsSingleElementVector =
+      isa<VectorType>(DataTy) &&
+      cast<FixedVectorType>(DataTy)->getNumElements() == 1;
+  Type *ScalarTy = DataTy->getScalarType();
+
+  if (ST->hasCF() && IsSingleElementVector &&
+      hasConditionalLoadStoreForType(ScalarTy))
+    return true;
+
   if (!ST->hasAVX())
     return false;
 
-  // The backend can't handle a single element vector.
-  if (isa<VectorType>(Dat...
[truncated]

@llvmbot
Member

llvmbot commented Jun 26, 2024

@llvm/pr-subscribers-backend-x86


Contributor

@phoebewang phoebewang left a comment


LGTM. Do we need to wait for the CFG patch?

@KanRobert
Contributor Author

LGTM. Do we need to wait for the CFG patch?

Thanks. I don't think we need to. We always use masked.load/store for the CFG transform; the only open questions are at which stage we emit them and in which cases doing so is profitable.

Also, the middle-end patch can only build after this lands.
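To make the dependency concrete, here is a minimal sketch (illustrative only; the actual middle-end transform is in #95515 and the names are invented) of the hoisting this enables. A load guarded by a branch cannot normally be speculated because it may fault, but rewriting it as a 1 x i32 masked load lets this patch select CFCMOV and removes the control flow:

```llvm
declare <1 x i32> @llvm.masked.load.v1i32.p0(ptr, i32 immarg, <1 x i1>, <1 x i32>)

; Before: the load must stay behind the branch, since speculating it could
; fault when %cond is false.
define i32 @guarded_load(i1 %cond, ptr %p) {
entry:
  br i1 %cond, label %if.then, label %exit
if.then:
  %v = load i32, ptr %p
  br label %exit
exit:
  %r = phi i32 [ %v, %if.then ], [ 0, %entry ]
  ret i32 %r
}

; After: the branch is gone; the masked load yields 0 and suppresses the fault
; when the mask is false, and this PR lowers it to a CFCMOV load.
define i32 @hoisted_load(i1 %cond, ptr %p) {
  %m = insertelement <1 x i1> poison, i1 %cond, i64 0
  %vv = call <1 x i32> @llvm.masked.load.v1i32.p0(ptr %p, i32 4, <1 x i1> %m, <1 x i32> zeroinitializer)
  %v = extractelement <1 x i32> %vv, i64 0
  ret i32 %v
}
```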

@KanRobert KanRobert merged commit 15fc801 into llvm:main Jun 27, 2024
5 of 6 checks passed

lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…lvm#96720)

AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…lvm#96720)
