Skip to content

Commit 2665b2a

Browse files
committed
[X86] Pull out combineConstantPoolLoads helper from combineLoad. NFC.
The logic is already pretty dense and a future patch will further complicate this.
1 parent 4e251e7 commit 2665b2a

File tree

1 file changed

+86
-51
lines changed

1 file changed

+86
-51
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 86 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -50823,10 +50823,83 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
5082350823
return SDValue();
5082450824
}
5082550825

50826+
/// Try to replace a simple, non-extending 128/256-bit vector load of a
/// constant pool entry with a subvector extract from another, wider constant
/// pool load/broadcast on the same chain whose low bits hold the same value.
///
/// \param N         The load being combined; must be a LoadSDNode.
/// \param dl        Debug location used for the nodes created here.
/// \param DAG       The selection DAG that owns \p N.
/// \param DCI       Combiner state, used to replace \p N via CombineTo.
/// \param Subtarget Queried for AVX support; the fold is skipped without it.
/// \returns The result of DCI.CombineTo when a matching wider constant is
///          found, otherwise an empty SDValue.
static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
                                        SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const X86Subtarget &Subtarget) {
  auto *Ld = cast<LoadSDNode>(N);
  EVT RegVT = Ld->getValueType(0);
  SDValue Ptr = Ld->getBasePtr();
  SDValue Chain = Ld->getChain();
  ISD::LoadExtType Ext = Ld->getExtensionType();

  if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
    return SDValue();

  if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
    return SDValue();

  // The constant behind our own pointer does not depend on the user being
  // inspected, so resolve it once and bail out early if there is none.
  const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
  if (!LdC)
    return SDValue();

  // Every element of the narrow constant that is not undef must be defined
  // in the wide constant and carry the same bits. Only the leading elements
  // covered by Undefs are compared.
  auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
                         ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
    for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
      if (Undefs[I])
        continue;
      if (UserUndefs[I] || Bits[I] != UserBits[I])
        return false;
    }
    return true;
  };

  // Look through all other loads/broadcasts in the chain for another constant
  // pool entry.
  for (SDNode *User : Chain->uses()) {
    auto *UserLd = dyn_cast<MemSDNode>(User);
    if (User == N || !UserLd)
      continue;

    bool IsCandidateOpcode =
        User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
        User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
        ISD::isNormalLoad(User);
    if (!IsCandidateOpcode)
      continue;

    // The candidate must hang off the same chain, must not have its chain
    // result used, and must produce a strictly wider value than this load.
    if (UserLd->getChain() != Chain || User->hasAnyUseOfValue(1) ||
        User->getValueSizeInBits(0).getFixedValue() <=
            RegVT.getFixedSizeInBits())
      continue;

    // See if we are loading a constant that matches in the lower
    // bits of a longer constant (but from a different constant pool ptr).
    SDValue UserPtr = UserLd->getBasePtr();
    const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
    if (!UserC || UserPtr == Ptr)
      continue;

    unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
    unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
    if (LdSize >= UserSize && ISD::isNormalLoad(User))
      continue;

    // Compare both constants at the narrower of the two element widths.
    EVT UserVT = User->getValueType(0);
    unsigned NumBits =
        std::min(RegVT.getScalarSizeInBits(), UserVT.getScalarSizeInBits());
    APInt Undefs, UserUndefs;
    SmallVector<APInt> Bits, UserBits;
    if (!getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs, Bits) ||
        !getTargetConstantBitsFromNode(SDValue(User, 0), NumBits, UserUndefs,
                                       UserBits))
      continue;

    if (!MatchingBits(Undefs, UserUndefs, Bits, UserBits))
      continue;

    // Reuse the wider value: take its lowest subvector and forward the
    // user's chain result in place of this load's chain.
    SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
                                       RegVT.getSizeInBits());
    Extract = DAG.getBitcast(RegVT, Extract);
    return DCI.CombineTo(N, Extract, SDValue(User, 1));
  }

  return SDValue();
}
50898+
5082650899
static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5082750900
TargetLowering::DAGCombinerInfo &DCI,
5082850901
const X86Subtarget &Subtarget) {
50829-
LoadSDNode *Ld = cast<LoadSDNode>(N);
50902+
auto *Ld = cast<LoadSDNode>(N);
5083050903
EVT RegVT = Ld->getValueType(0);
5083150904
EVT MemVT = Ld->getMemoryVT();
5083250905
SDLoc dl(Ld);
@@ -50885,7 +50958,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5088550958
}
5088650959
}
5088750960

50888-
// If we also load/broadcast this to a wider type, then just extract the
50961+
// If we also broadcast this vector to a wider type, then just extract the
5088950962
// lowest subvector.
5089050963
if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
5089150964
(RegVT.is128BitVector() || RegVT.is256BitVector())) {
@@ -50894,61 +50967,23 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
5089450967
for (SDNode *User : Chain->uses()) {
5089550968
auto *UserLd = dyn_cast<MemSDNode>(User);
5089650969
if (User != N && UserLd &&
50897-
(User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
50898-
User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
50899-
ISD::isNormalLoad(User)) &&
50900-
UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
50970+
User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50971+
UserLd->getChain() == Chain && UserLd->getBasePtr() == Ptr &&
50972+
UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits() &&
50973+
!User->hasAnyUseOfValue(1) &&
5090150974
User->getValueSizeInBits(0).getFixedValue() >
5090250975
RegVT.getFixedSizeInBits()) {
50903-
if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50904-
UserLd->getBasePtr() == Ptr &&
50905-
UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
50906-
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
50907-
RegVT.getSizeInBits());
50908-
Extract = DAG.getBitcast(RegVT, Extract);
50909-
return DCI.CombineTo(N, Extract, SDValue(User, 1));
50910-
}
50911-
auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
50912-
ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
50913-
for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
50914-
if (Undefs[I])
50915-
continue;
50916-
if (UserUndefs[I] || Bits[I] != UserBits[I])
50917-
return false;
50918-
}
50919-
return true;
50920-
};
50921-
// See if we are loading a constant that matches in the lower
50922-
// bits of a longer constant (but from a different constant pool ptr).
50923-
EVT UserVT = User->getValueType(0);
50924-
SDValue UserPtr = UserLd->getBasePtr();
50925-
const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
50926-
const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
50927-
if (LdC && UserC && UserPtr != Ptr) {
50928-
unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
50929-
unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
50930-
if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
50931-
APInt Undefs, UserUndefs;
50932-
SmallVector<APInt> Bits, UserBits;
50933-
unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
50934-
UserVT.getScalarSizeInBits());
50935-
if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
50936-
Bits) &&
50937-
getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
50938-
UserUndefs, UserBits)) {
50939-
if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
50940-
SDValue Extract = extractSubVector(
50941-
SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
50942-
Extract = DAG.getBitcast(RegVT, Extract);
50943-
return DCI.CombineTo(N, Extract, SDValue(User, 1));
50944-
}
50945-
}
50946-
}
50947-
}
50976+
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
50977+
RegVT.getSizeInBits());
50978+
Extract = DAG.getBitcast(RegVT, Extract);
50979+
return DCI.CombineTo(N, Extract, SDValue(User, 1));
5094850980
}
5094950981
}
5095050982
}
5095150983

50984+
if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
50985+
return V;
50986+
5095250987
// Cast ptr32 and ptr64 pointers to the default address space before a load.
5095350988
unsigned AddrSpace = Ld->getAddressSpace();
5095450989
if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||

0 commit comments

Comments
 (0)