Skip to content

Commit 1baf9a1

Browse files
KanRobert authored and lravenclaw committed
[X86][CodeGen] Support hoisting load/store with conditional faulting (llvm#96720)
1. Add a TTI interface for conditional load/store. 2. Mark 1 x i16/i32/i64 masked load/store as legal so that they are not scalarized by the scalarize-masked-mem-intrin pass. 3. Visit 1 x i16/i32/i64 masked load/store to build a target-specific CLOAD/CSTORE node, avoiding an error in `DAGTypeLegalizer::ScalarizeVectorResult`. 4. Add a DAG combine to simplify CLOAD/CSTORE nodes. 5. Lower CLOAD/CSTORE to CFCMOV by pattern matching. This is the CodeGen part of llvm#95515.
1 parent 2dc04f6 commit 1baf9a1

File tree

12 files changed

+317
-15
lines changed

12 files changed

+317
-15
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,10 @@ class TargetTransformInfo {
11131113
/// \return the number of registers in the target-provided register class.
11141114
unsigned getNumberOfRegisters(unsigned ClassID) const;
11151115

1116+
/// \return true if the target supports load/store that enables fault
1117+
/// suppression of memory operands when the source condition is false.
1118+
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
1119+
11161120
/// \return the target-provided register class ID for the provided type,
11171121
/// accounting for type promotion and other type-legalization techniques that
11181122
/// the target might apply. However, it specifically does not account for the
@@ -1956,6 +1960,7 @@ class TargetTransformInfo::Concept {
19561960
virtual bool preferToKeepConstantsAttached(const Instruction &Inst,
19571961
const Function &Fn) const = 0;
19581962
virtual unsigned getNumberOfRegisters(unsigned ClassID) const = 0;
1963+
virtual bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const = 0;
19591964
virtual unsigned getRegisterClassForType(bool Vector,
19601965
Type *Ty = nullptr) const = 0;
19611966
virtual const char *getRegisterClassName(unsigned ClassID) const = 0;
@@ -2543,6 +2548,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25432548
unsigned getNumberOfRegisters(unsigned ClassID) const override {
25442549
return Impl.getNumberOfRegisters(ClassID);
25452550
}
2551+
// Forwards the conditional load/store query for Ty to the wrapped
// implementation object (Impl) and returns its answer unchanged.
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const override {
2552+
return Impl.hasConditionalLoadStoreForType(Ty);
2553+
}
25462554
unsigned getRegisterClassForType(bool Vector,
25472555
Type *Ty = nullptr) const override {
25482556
return Impl.getRegisterClassForType(Vector, Ty);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ class TargetTransformInfoImplBase {
457457
}
458458

459459
unsigned getNumberOfRegisters(unsigned ClassID) const { return 8; }
460+
// Base-implementation default: reports no conditional load/store support,
// regardless of Ty.
bool hasConditionalLoadStoreForType(Type *Ty) const { return false; }
460461

461462
unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
462463
return Vector ? 1 : 0;

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3895,6 +3895,20 @@ class TargetLowering : public TargetLoweringBase {
38953895
const SDValue OldRHS, SDValue &Chain,
38963896
bool IsSignaling = false) const;
38973897

3898+
/// Target hook that builds a target-specific node for a masked load.
/// SelectionDAGBuilder calls it for non-expanding masked loads when TTI
/// reports conditional load/store support for the element type.
/// \p NewLoad is an output parameter: the builder takes result #1 of it for
/// its pending-load list, while the returned value is what gets recorded as
/// the value of the intrinsic call.
/// The default implementation is unreachable; a target must override it
/// before reporting support through TTI.
virtual SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL,
3899+
SDValue Chain, MachineMemOperand *MMO,
3900+
SDValue &NewLoad, SDValue Ptr,
3901+
SDValue PassThru, SDValue Mask) const {
3902+
llvm_unreachable("Not Implemented");
3903+
}
3904+
3905+
/// Target hook that builds a target-specific node for a masked store.
/// SelectionDAGBuilder calls it for non-compressing masked stores when TTI
/// reports conditional load/store support for the element type, and installs
/// the returned node as the new DAG root.
/// The default implementation is unreachable; a target must override it
/// before reporting support through TTI.
virtual SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
3906+
SDValue Chain, MachineMemOperand *MMO,
3907+
SDValue Ptr, SDValue Val,
3908+
SDValue Mask) const {
3909+
llvm_unreachable("Not Implemented");
3910+
}
3911+
38983912
/// Returns a pair of (return value, chain).
38993913
/// It is an error to pass RTLIB::UNKNOWN_LIBCALL as \p LC.
39003914
std::pair<SDValue, SDValue> makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC,

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,10 @@ unsigned TargetTransformInfo::getNumberOfRegisters(unsigned ClassID) const {
722722
return TTIImpl->getNumberOfRegisters(ClassID);
723723
}
724724

725+
// Public TTI entry point: delegates the conditional load/store query for Ty
// to the target implementation held in TTIImpl.
bool TargetTransformInfo::hasConditionalLoadStoreForType(Type *Ty) const {
726+
return TTIImpl->hasConditionalLoadStoreForType(Ty);
727+
}
728+
725729
unsigned TargetTransformInfo::getRegisterClassForType(bool Vector,
726730
Type *Ty) const {
727731
return TTIImpl->getRegisterClassForType(Vector, Ty);

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4783,9 +4783,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
47834783
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
47844784
MachinePointerInfo(PtrOperand), MMOFlags,
47854785
LocationSize::beforeOrAfterPointer(), Alignment, I.getAAMetadata());
4786+
4787+
const auto &TLI = DAG.getTargetLoweringInfo();
4788+
const auto &TTI =
4789+
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
47864790
SDValue StoreNode =
4787-
DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask, VT, MMO,
4788-
ISD::UNINDEXED, false /* Truncating */, IsCompressing);
4791+
!IsCompressing && TTI.hasConditionalLoadStoreForType(
4792+
I.getArgOperand(0)->getType()->getScalarType())
4793+
? TLI.visitMaskedStore(DAG, sdl, getMemoryRoot(), MMO, Ptr, Src0,
4794+
Mask)
4795+
: DAG.getMaskedStore(getMemoryRoot(), sdl, Src0, Ptr, Offset, Mask,
4796+
VT, MMO, ISD::UNINDEXED, /*Truncating=*/false,
4797+
IsCompressing);
47894798
DAG.setRoot(StoreNode);
47904799
setValue(&I, StoreNode);
47914800
}
@@ -4958,12 +4967,23 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) {
49584967
MachinePointerInfo(PtrOperand), MMOFlags,
49594968
LocationSize::beforeOrAfterPointer(), Alignment, AAInfo, Ranges);
49604969

4961-
SDValue Load =
4962-
DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
4963-
ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
4970+
const auto &TLI = DAG.getTargetLoweringInfo();
4971+
const auto &TTI =
4972+
TLI.getTargetMachine().getTargetTransformInfo(*I.getFunction());
4973+
// The Load/Res may point to different values and both of them are output
4974+
// variables.
4975+
SDValue Load;
4976+
SDValue Res;
4977+
if (!IsExpanding && TTI.hasConditionalLoadStoreForType(
4978+
Src0Operand->getType()->getScalarType()))
4979+
Res = TLI.visitMaskedLoad(DAG, sdl, InChain, MMO, Load, Ptr, Src0, Mask);
4980+
else
4981+
Res = Load =
4982+
DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Offset, Mask, Src0, VT, MMO,
4983+
ISD::UNINDEXED, ISD::NON_EXTLOAD, IsExpanding);
49644984
if (AddToChain)
49654985
PendingLoads.push_back(Load.getValue(1));
4966-
setValue(&I, Load);
4986+
setValue(&I, Res);
49674987
}
49684988

49694989
void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32315,6 +32315,54 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
3231532315
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
3231632316
}
3231732317

32318+
// Builds the flags value used as the condition input of CLOAD/CSTORE.
// Mask is bitcast to i1, converted to i8, and subtracted from an i8 zero with
// an X86ISD::SUB node whose results are (i8, i32). The returned SDValue is
// result #1 of that node (the i32 one), which the callers in this file pass as
// the flags operand together with COND_NE.
static SDValue getFlagsOfCmpZeroFori1(SelectionDAG &DAG, const SDLoc &DL,
32319+
SDValue Mask) {
32320+
EVT Ty = MVT::i8;
32321+
auto V = DAG.getBitcast(MVT::i1, Mask);
32322+
auto VE = DAG.getZExtOrTrunc(V, DL, Ty);
32323+
auto Zero = DAG.getConstant(0, DL, Ty);
32324+
SDVTList X86SubVTs = DAG.getVTList(Ty, MVT::i32);
32325+
auto CmpZero = DAG.getNode(X86ISD::SUB, DL, X86SubVTs, Zero, VE);
32326+
// Only the second result of the SUB is of interest; the i8 difference is
// left unused here.
return SDValue(CmpZero.getNode(), 1);
32327+
}
32328+
32329+
// Builds an X86ISD::CLOAD memory-intrinsic node for a masked load.
// The element type is taken from PassThru's vector type. NewLoad receives the
// CLOAD node itself (results: scalar value, chain); the return value is that
// node bitcast back to PassThru's vector type.
// NOTE(review): the bitcasts between the vector and its element type assume a
// single-element vector -- confirm that callers guarantee this.
SDValue X86TargetLowering::visitMaskedLoad(
32330+
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
32331+
SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
32332+
// @llvm.masked.load.v1*(ptr, alignment, mask, passthru)
32333+
// ->
32334+
// _, flags = SUB 0, mask
32335+
// res, chain = CLOAD inchain, ptr, (bit_cast_to_scalar passthru), cond, flags
32336+
// bit_cast_to_vector<res>
32337+
EVT VTy = PassThru.getValueType();
32338+
EVT Ty = VTy.getVectorElementType();
32339+
SDVTList Tys = DAG.getVTList(Ty, MVT::Other);
32340+
// An undef passthru is replaced by a zero constant of the element type;
// presumably this lets the zero-passthru selection patterns match.
auto ScalarPassThru = PassThru.isUndef() ? DAG.getConstant(0, DL, Ty)
32341+
: DAG.getBitcast(Ty, PassThru);
32342+
auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
32343+
auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
32344+
// Operand order: chain, pointer, passthru, condition code, flags.
SDValue Ops[] = {Chain, Ptr, ScalarPassThru, COND_NE, Flags};
32345+
NewLoad = DAG.getMemIntrinsicNode(X86ISD::CLOAD, DL, Tys, Ops, Ty, MMO);
32346+
return DAG.getBitcast(VTy, NewLoad);
32347+
}
32348+
32349+
// Builds an X86ISD::CSTORE memory-intrinsic node for a masked store.
// Val is bitcast to its vector element type and stored under the condition
// COND_NE evaluated on the flags derived from Mask. The node's only result is
// the output chain, which is what this function returns.
// NOTE(review): the vector-to-element bitcast assumes a single-element
// vector -- confirm that callers guarantee this.
SDValue X86TargetLowering::visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL,
32350+
SDValue Chain,
32351+
MachineMemOperand *MMO, SDValue Ptr,
32352+
SDValue Val, SDValue Mask) const {
32353+
// llvm.masked.store.v1*(Src0, Ptr, alignment, Mask)
32354+
// ->
32355+
// _, flags = SUB 0, mask
32356+
// chain = CSTORE inchain, (bit_cast_to_scalar val), ptr, cond, flags
32357+
EVT Ty = Val.getValueType().getVectorElementType();
32358+
SDVTList Tys = DAG.getVTList(MVT::Other);
32359+
auto ScalarVal = DAG.getBitcast(Ty, Val);
32360+
auto Flags = getFlagsOfCmpZeroFori1(DAG, DL, Mask);
32361+
auto COND_NE = DAG.getTargetConstant(X86::COND_NE, DL, MVT::i8);
32362+
// Operand order: chain, value, pointer, condition code, flags.
SDValue Ops[] = {Chain, ScalarVal, Ptr, COND_NE, Flags};
32363+
return DAG.getMemIntrinsicNode(X86ISD::CSTORE, DL, Tys, Ops, Ty, MMO);
32364+
}
32365+
3231832366
/// Provide custom lowering hooks for some operations.
3231932367
SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3232032368
switch (Op.getOpcode()) {
@@ -34031,6 +34079,8 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
3403134079
NODE_NAME_CASE(STRICT_FP80_ADD)
3403234080
NODE_NAME_CASE(CCMP)
3403334081
NODE_NAME_CASE(CTEST)
34082+
NODE_NAME_CASE(CLOAD)
34083+
NODE_NAME_CASE(CSTORE)
3403434084
}
3403534085
return nullptr;
3403634086
#undef NODE_NAME_CASE
@@ -55636,6 +55686,32 @@ static SDValue combineSubSetcc(SDNode *N, SelectionDAG &DAG) {
5563655686
return SDValue();
5563755687
}
5563855688

55689+
// DAG combine for X86ISD::CLOAD / X86ISD::CSTORE nodes.
// When the node's condition (operand 3) is COND_NE and its flags (operand 4)
// come from an X86ISD::SUB of zero and an X86ISD::SETCC, the node is rebuilt
// with the SETCC's own condition code and flags in operands 3 and 4, keeping
// all other operands, the value-type list, memory VT and memory operand.
// Returns an empty SDValue when the pattern does not match.
static SDValue combineX86CloadCstore(SDNode *N, SelectionDAG &DAG) {
55690+
// res, flags2 = sub 0, (setcc cc, flag)
55691+
// cload/cstore ..., cond_ne, flag2
55692+
// ->
55693+
// cload/cstore cc, flag
55694+
if (N->getConstantOperandVal(3) != X86::COND_NE)
55695+
return SDValue();
55696+
55697+
SDValue Sub = N->getOperand(4);
55698+
if (Sub.getOpcode() != X86ISD::SUB)
55699+
return SDValue();
55700+
55701+
SDValue SetCC = Sub.getOperand(1);
55702+
55703+
if (!X86::isZeroNode(Sub.getOperand(0)) || SetCC.getOpcode() != X86ISD::SETCC)
55704+
return SDValue();
55705+
55706+
// Copy every operand, then overwrite only the condition code and flags.
SmallVector<SDValue, 5> Ops(N->op_values());
55707+
Ops[3] = SetCC.getOperand(0);
55708+
Ops[4] = SetCC.getOperand(1);
55709+
55710+
return DAG.getMemIntrinsicNode(N->getOpcode(), SDLoc(N), N->getVTList(), Ops,
55711+
cast<MemSDNode>(N)->getMemoryVT(),
55712+
cast<MemSDNode>(N)->getMemOperand());
55713+
}
55714+
5563955715
static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
5564055716
TargetLowering::DAGCombinerInfo &DCI,
5564155717
const X86Subtarget &Subtarget) {
@@ -57345,6 +57421,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
5734557421
case ISD::SUB: return combineSub(N, DAG, DCI, Subtarget);
5734657422
case X86ISD::ADD:
5734757423
case X86ISD::SUB: return combineX86AddSub(N, DAG, DCI, Subtarget);
57424+
case X86ISD::CLOAD:
57425+
case X86ISD::CSTORE: return combineX86CloadCstore(N, DAG);
5734857426
case X86ISD::SBB: return combineSBB(N, DAG);
5734957427
case X86ISD::ADC: return combineADC(N, DAG, DCI);
5735057428
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,10 @@ namespace llvm {
903903
// is needed so that this can be expanded with control flow.
904904
VASTART_SAVE_XMM_REGS,
905905

906+
// Conditional load/store instructions
907+
CLOAD,
908+
CSTORE,
909+
906910
// WARNING: Do not add anything in the end unless you want the node to
907911
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
908912
// opcodes will be thought as target memory ops!
@@ -1556,6 +1560,14 @@ namespace llvm {
15561560
bool isInlineAsmTargetBranch(const SmallVectorImpl<StringRef> &AsmStrs,
15571561
unsigned OpNo) const override;
15581562

1563+
/// X86 override of the masked-load hook. \p NewLoad is an output parameter
/// set by the implementation; the return value is the result handed back to
/// the caller.
SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
1564+
MachineMemOperand *MMO, SDValue &NewLoad,
1565+
SDValue Ptr, SDValue PassThru,
1566+
SDValue Mask) const override;
1567+
/// X86 override of the masked-store hook; returns the node built for the
/// store.
SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
1568+
MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
1569+
SDValue Mask) const override;
1570+
15591571
/// Lower interleaved load(s) into target specific
15601572
/// instructions/intrinsics.
15611573
bool lowerInterleavedLoad(LoadInst *LI,

llvm/lib/Target/X86/X86InstrCMovSetCC.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,27 @@ let Predicates = [HasCMOV, HasCF] in {
113113
(CFCMOV32rr GR32:$src1, (inv_cond_XFORM timm:$cond))>;
114114
def : Pat<(X86cmov GR64:$src1, 0, timm:$cond, EFLAGS),
115115
(CFCMOV64rr GR64:$src1, (inv_cond_XFORM timm:$cond))>;
116+
117+
// Conditional load with a zero passthru: selected to the CFCMOV*rm forms,
// which take only the address and the condition code.
def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
118+
(CFCMOV16rm addr:$src1, timm:$cond)>;
119+
def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
120+
(CFCMOV32rm addr:$src1, timm:$cond)>;
121+
def : Pat<(X86cload addr:$src1, 0, timm:$cond, EFLAGS),
122+
(CFCMOV64rm addr:$src1, timm:$cond)>;
123+
124+
// Conditional load with a register passthru: selected to the CFCMOV*rm_ND
// forms, which additionally take the passthru register as $src1.
def : Pat<(X86cload addr:$src2, GR16:$src1, timm:$cond, EFLAGS),
125+
(CFCMOV16rm_ND GR16:$src1, addr:$src2, timm:$cond)>;
126+
def : Pat<(X86cload addr:$src2, GR32:$src1, timm:$cond, EFLAGS),
127+
(CFCMOV32rm_ND GR32:$src1, addr:$src2, timm:$cond)>;
128+
def : Pat<(X86cload addr:$src2, GR64:$src1, timm:$cond, EFLAGS),
129+
(CFCMOV64rm_ND GR64:$src1, addr:$src2, timm:$cond)>;
130+
131+
// Conditional store of a 16/32/64-bit register: selected to the CFCMOV*mr
// forms (address first, then the value register and the condition code).
def : Pat<(X86cstore GR16:$src2, addr:$src1, timm:$cond, EFLAGS),
132+
(CFCMOV16mr addr:$src1, GR16:$src2, timm:$cond)>;
133+
def : Pat<(X86cstore GR32:$src2, addr:$src1, timm:$cond, EFLAGS),
134+
(CFCMOV32mr addr:$src1, GR32:$src2, timm:$cond)>;
135+
def : Pat<(X86cstore GR64:$src2, addr:$src1, timm:$cond, EFLAGS),
136+
(CFCMOV64mr addr:$src1, GR64:$src2, timm:$cond)>;
116137
}
117138

118139
// SetCC instructions.

llvm/lib/Target/X86/X86InstrFragments.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ def SDTX86FCmp : SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisFP<1>,
1515
def SDTX86Ccmp : SDTypeProfile<1, 5,
1616
[SDTCisVT<3, i8>, SDTCisVT<4, i8>, SDTCisVT<5, i32>]>;
1717

18+
// RES = op PTR, PASSTHRU, COND, EFLAGS
// One integer result and four operands: a pointer, a passthru constrained to
// the result type, an i8 condition code and an i32 EFLAGS value.
19+
def SDTX86Cload : SDTypeProfile<1, 4,
20+
[SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisSameAs<0, 2>,
21+
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
22+
// op VAL, PTR, COND, EFLAGS
// No result and four operands: an integer value, a pointer, an i8 condition
// code and an i32 EFLAGS value.
23+
def SDTX86Cstore : SDTypeProfile<0, 4,
24+
[SDTCisInt<0>, SDTCisPtrTy<1>,
25+
SDTCisVT<2, i8>, SDTCisVT<3, i32>]>;
26+
1827
def SDTX86Cmov : SDTypeProfile<1, 4,
1928
[SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>,
2029
SDTCisVT<3, i8>, SDTCisVT<4, i32>]>;
@@ -144,6 +153,9 @@ def X86bt : SDNode<"X86ISD::BT", SDTX86CmpTest>;
144153
def X86ccmp : SDNode<"X86ISD::CCMP", SDTX86Ccmp>;
145154
def X86ctest : SDNode<"X86ISD::CTEST", SDTX86Ccmp>;
146155

156+
// Selection-DAG nodes for X86ISD::CLOAD and X86ISD::CSTORE. Both carry a
// chain and a memory operand; CLOAD is marked as possibly loading and CSTORE
// as possibly storing.
def X86cload : SDNode<"X86ISD::CLOAD", SDTX86Cload, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
157+
def X86cstore : SDNode<"X86ISD::CSTORE", SDTX86Cstore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
158+
147159
def X86cmov : SDNode<"X86ISD::CMOV", SDTX86Cmov>;
148160
def X86brcond : SDNode<"X86ISD::BRCOND", SDTX86BrCond,
149161
[SDNPHasChain]>;

llvm/lib/Target/X86/X86TargetTransformInfo.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,27 @@ unsigned X86TTIImpl::getNumberOfRegisters(unsigned ClassID) const {
176176
return 8;
177177
}
178178

179+
// Reports whether the subtarget can do a conditional load/store of Ty.
// Returns false when the subtarget lacks CF. With CF, a null Ty returns true,
// and a non-null Ty returns true only for 16-, 32- or 64-bit integer types.
// NOTE(review): the callers visible alongside this function pass the element
// type (getScalarType()), so the element count of the original vector is not
// checked here -- confirm that only single-element vectors reach the
// CLOAD/CSTORE path.
bool X86TTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
180+
if (!ST->hasCF())
181+
return false;
182+
if (!Ty)
183+
return true;
184+
// Conditional faulting is supported by CFCMOV, which only accepts
185+
// 16/32/64-bit operands.
186+
// TODO: Support f32/f64 with VMOVSS/VMOVSD with zero mask when it's
187+
// profitable.
188+
if (!Ty->isIntegerTy())
189+
return false;
190+
switch (cast<IntegerType>(Ty)->getBitWidth()) {
191+
default:
192+
return false;
193+
case 16:
194+
case 32:
195+
case 64:
196+
return true;
197+
}
198+
}
199+
179200
TypeSize
180201
X86TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
181202
unsigned PreferVectorWidth = ST->getPreferVectorWidth();
@@ -5070,17 +5091,22 @@ X86TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *SrcTy, Align Alignment,
50705091
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcVTy);
50715092
auto VT = TLI->getValueType(DL, SrcVTy);
50725093
InstructionCost Cost = 0;
5073-
if (VT.isSimple() && LT.second != VT.getSimpleVT() &&
5094+
MVT Ty = LT.second;
5095+
if (Ty == MVT::i16 || Ty == MVT::i32 || Ty == MVT::i64)
5096+
// APX masked load/store for scalar is cheap.
5097+
return Cost + LT.first;
5098+
5099+
if (VT.isSimple() && Ty != VT.getSimpleVT() &&
50745100
LT.second.getVectorNumElements() == NumElem)
50755101
// Promotion requires extend/truncate for data and a shuffle for mask.
50765102
Cost += getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVTy, std::nullopt,
50775103
CostKind, 0, nullptr) +
50785104
getShuffleCost(TTI::SK_PermuteTwoSrc, MaskTy, std::nullopt,
50795105
CostKind, 0, nullptr);
50805106

5081-
else if (LT.first * LT.second.getVectorNumElements() > NumElem) {
5107+
else if (LT.first * Ty.getVectorNumElements() > NumElem) {
50825108
auto *NewMaskTy = FixedVectorType::get(MaskTy->getElementType(),
5083-
LT.second.getVectorNumElements());
5109+
Ty.getVectorNumElements());
50845110
// Expanding requires fill mask with zeroes
50855111
Cost += getShuffleCost(TTI::SK_InsertSubvector, NewMaskTy, std::nullopt,
50865112
CostKind, 0, MaskTy);
@@ -5899,14 +5925,14 @@ bool X86TTIImpl::canMacroFuseCmp() {
58995925
}
59005926

59015927
bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment) {
5902-
if (!ST->hasAVX())
5903-
return false;
5928+
Type *ScalarTy = DataTy->getScalarType();
59045929

5905-
// The backend can't handle a single element vector.
5906-
if (isa<VectorType>(DataTy) &&
5907-
cast<FixedVectorType>(DataTy)->getNumElements() == 1)
5930+
// The backend can't handle a single element vector w/o CFCMOV.
5931+
if (isa<VectorType>(DataTy) && cast<FixedVectorType>(DataTy)->getNumElements() == 1)
5932+
return ST->hasCF() && hasConditionalLoadStoreForType(ScalarTy);
5933+
5934+
if (!ST->hasAVX())
59085935
return false;
5909-
Type *ScalarTy = DataTy->getScalarType();
59105936

59115937
if (ScalarTy->isPointerTy())
59125938
return true;

llvm/lib/Target/X86/X86TargetTransformInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
132132
/// @{
133133

134134
unsigned getNumberOfRegisters(unsigned ClassID) const;
135+
bool hasConditionalLoadStoreForType(Type *Ty = nullptr) const;
135136
TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
136137
unsigned getLoadStoreVecRegBitWidth(unsigned AS) const;
137138
unsigned getMaxInterleaveFactor(ElementCount VF);

0 commit comments

Comments
 (0)