Skip to content

Commit d225de6

Browse files
committed
Revert "[X86][AVX] Add getBROADCAST_LOAD helper function. NFCI."
This reverts commit 1cfecf4. This commit broke LLVM code generated through XLA by removing a conditional on Ld->getExtensionType() == ISD::NON_EXTLOAD This is not a perfect revert. The new function is left as other uses of it exist now.
1 parent 70fa947 commit d225de6

File tree

1 file changed

+60
-25
lines changed

1 file changed

+60
-25
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16084,12 +16084,21 @@ static SDValue lowerV2X128Shuffle(const SDLoc &DL, MVT VT, SDValue V1,
1608416084
bool SplatHi = isShuffleEquivalent(Mask, {2, 3, 2, 3}, V1);
1608516085
if ((SplatLo || SplatHi) && !Subtarget.hasAVX512() && V1.hasOneUse() &&
1608616086
MayFoldLoad(peekThroughOneUseBitcasts(V1))) {
16087-
MVT MemVT = VT.getHalfNumVectorElementsVT();
16088-
unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
1608916087
auto *Ld = cast<LoadSDNode>(peekThroughOneUseBitcasts(V1));
16090-
if (SDValue BcstLd = getBROADCAST_LOAD(X86ISD::SUBV_BROADCAST_LOAD, DL,
16091-
VT, MemVT, Ld, Ofs, DAG))
16092-
return BcstLd;
16088+
if (!Ld->isNonTemporal()) {
16089+
MVT MemVT = VT.getHalfNumVectorElementsVT();
16090+
unsigned Ofs = SplatLo ? 0 : MemVT.getStoreSize();
16091+
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
16092+
SDValue Ptr = DAG.getMemBasePlusOffset(Ld->getBasePtr(),
16093+
TypeSize::Fixed(Ofs), DL);
16094+
SDValue Ops[] = {Ld->getChain(), Ptr};
16095+
SDValue BcastLd = DAG.getMemIntrinsicNode(
16096+
X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops, MemVT,
16097+
DAG.getMachineFunction().getMachineMemOperand(
16098+
Ld->getMemOperand(), Ofs, MemVT.getStoreSize()));
16099+
DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
16100+
return BcastLd;
16101+
}
1609316102
}
1609416103

1609516104
// With AVX2, use VPERMQ/VPERMPD for unary shuffles to allow memory folding.
@@ -38011,7 +38020,7 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
3801138020
return Res;
3801238021

3801338022
// Fold vperm2x128 subvector shuffle with an inner concat pattern.
38014-
// vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
38023+
// vperm2x128(concat(X,Y),concat(Z,W)) --> concat X,Y etc.
3801538024
auto FindSubVector128 = [&](unsigned Idx) {
3801638025
if (Idx > 3)
3801738026
return SDValue();
@@ -38992,10 +39001,10 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3899239001
}
3899339002
// Subvector broadcast.
3899439003
case X86ISD::SUBV_BROADCAST_LOAD: {
38995-
SDLoc DL(Op);
3899639004
auto *MemIntr = cast<MemIntrinsicSDNode>(Op);
3899739005
EVT MemVT = MemIntr->getMemoryVT();
3899839006
if (ExtSizeInBits == MemVT.getStoreSizeInBits()) {
39007+
SDLoc DL(Op);
3899939008
SDValue Ld =
3900039009
TLO.DAG.getLoad(MemVT, DL, MemIntr->getChain(),
3900139010
MemIntr->getBasePtr(), MemIntr->getMemOperand());
@@ -39004,13 +39013,18 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
3900439013
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Ld, 0,
3900539014
TLO.DAG, DL, ExtSizeInBits));
3900639015
} else if ((ExtSizeInBits % MemVT.getStoreSizeInBits()) == 0) {
39016+
SDLoc DL(Op);
3900739017
EVT BcstVT = EVT::getVectorVT(*TLO.DAG.getContext(), VT.getScalarType(),
3900839018
ExtSizeInBits / VT.getScalarSizeInBits());
39009-
if (SDValue BcstLd =
39010-
getBROADCAST_LOAD(Opc, DL, BcstVT, MemVT, MemIntr, 0, TLO.DAG))
39011-
return TLO.CombineTo(Op,
39012-
insertSubVector(TLO.DAG.getUNDEF(VT), BcstLd, 0,
39013-
TLO.DAG, DL, ExtSizeInBits));
39019+
SDVTList Tys = TLO.DAG.getVTList(BcstVT, MVT::Other);
39020+
SDValue Ops[] = {MemIntr->getOperand(0), MemIntr->getOperand(1)};
39021+
SDValue Bcst =
39022+
TLO.DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys,
39023+
Ops, MemVT, MemIntr->getMemOperand());
39024+
TLO.DAG.makeEquivalentMemoryOrdering(SDValue(MemIntr, 1),
39025+
Bcst.getValue(1));
39026+
return TLO.CombineTo(Op, insertSubVector(TLO.DAG.getUNDEF(VT), Bcst, 0,
39027+
TLO.DAG, DL, ExtSizeInBits));
3901439028
}
3901539029
break;
3901639030
}
@@ -50083,21 +50097,36 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5008350097
if (Op0.getOpcode() == X86ISD::VBROADCAST)
5008450098
return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0));
5008550099

50086-
// If this simple subvector or scalar/subvector broadcast_load is inserted
50087-
// into both halves, use a larger broadcast_load. Update other uses to use
50088-
// an extracted subvector.
50089-
if (Op0.getOpcode() == ISD::LOAD ||
50090-
Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
50100+
// If this scalar/subvector broadcast_load is inserted into both halves, use
50101+
// a larger broadcast_load. Update other uses to use an extracted subvector.
50102+
if (Op0.getOpcode() == X86ISD::VBROADCAST_LOAD ||
5009150103
Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) {
50092-
auto *Mem = cast<MemSDNode>(Op0);
50093-
unsigned Opcode = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD
50094-
? X86ISD::VBROADCAST_LOAD
50095-
: X86ISD::SUBV_BROADCAST_LOAD;
50096-
if (SDValue BcastLd = getBROADCAST_LOAD(
50097-
Opcode, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) {
50104+
auto *MemIntr = cast<MemIntrinsicSDNode>(Op0);
50105+
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
50106+
SDValue Ops[] = {MemIntr->getChain(), MemIntr->getBasePtr()};
50107+
SDValue BcastLd = DAG.getMemIntrinsicNode(Op0.getOpcode(), DL, Tys, Ops,
50108+
MemIntr->getMemoryVT(),
50109+
MemIntr->getMemOperand());
50110+
DAG.ReplaceAllUsesOfValueWith(
50111+
Op0, extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
50112+
DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
50113+
return BcastLd;
50114+
}
50115+
50116+
// If this is a simple subvector load repeated across multiple lanes, then
50117+
// broadcast the load. Update other uses to use an extracted subvector.
50118+
if (auto *Ld = dyn_cast<LoadSDNode>(Op0)) {
50119+
if (Ld->isSimple() && !Ld->isNonTemporal() &&
50120+
Ld->getExtensionType() == ISD::NON_EXTLOAD) {
50121+
SDVTList Tys = DAG.getVTList(VT, MVT::Other);
50122+
SDValue Ops[] = {Ld->getChain(), Ld->getBasePtr()};
50123+
SDValue BcastLd =
50124+
DAG.getMemIntrinsicNode(X86ISD::SUBV_BROADCAST_LOAD, DL, Tys, Ops,
50125+
Ld->getMemoryVT(), Ld->getMemOperand());
5009850126
DAG.ReplaceAllUsesOfValueWith(
5009950127
Op0,
5010050128
extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()));
50129+
DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), BcastLd.getValue(1));
5010150130
return BcastLd;
5010250131
}
5010350132
}
@@ -50461,8 +50490,14 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
5046150490
if (Vec.isUndef() && IdxVal != 0 && SubVec.hasOneUse() &&
5046250491
SubVec.getOpcode() == X86ISD::VBROADCAST_LOAD) {
5046350492
auto *MemIntr = cast<MemIntrinsicSDNode>(SubVec);
50464-
return getBROADCAST_LOAD(X86ISD::VBROADCAST_LOAD, dl, OpVT,
50465-
MemIntr->getMemoryVT(), MemIntr, 0, DAG);
50493+
SDVTList Tys = DAG.getVTList(OpVT, MVT::Other);
50494+
SDValue Ops[] = { MemIntr->getChain(), MemIntr->getBasePtr() };
50495+
SDValue BcastLd =
50496+
DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops,
50497+
MemIntr->getMemoryVT(),
50498+
MemIntr->getMemOperand());
50499+
DAG.ReplaceAllUsesOfValueWith(SDValue(MemIntr, 1), BcastLd.getValue(1));
50500+
return BcastLd;
5046650501
}
5046750502

5046850503
// If we're splatting the lower half subvector of a full vector load into the

0 commit comments

Comments
 (0)