@@ -50823,10 +50823,83 @@ static SDValue detectAVGPattern(SDValue In, EVT VT, SelectionDAG &DAG,
50823
50823
return SDValue();
50824
50824
}
50825
50825
50826
+ static SDValue combineConstantPoolLoads(SDNode *N, const SDLoc &dl,
50827
+ SelectionDAG &DAG,
50828
+ TargetLowering::DAGCombinerInfo &DCI,
50829
+ const X86Subtarget &Subtarget) {
50830
+ auto *Ld = cast<LoadSDNode>(N);
50831
+ EVT RegVT = Ld->getValueType(0);
50832
+ EVT MemVT = Ld->getMemoryVT();
50833
+ SDValue Ptr = Ld->getBasePtr();
50834
+ SDValue Chain = Ld->getChain();
50835
+ ISD::LoadExtType Ext = Ld->getExtensionType();
50836
+
50837
+ if (Ext != ISD::NON_EXTLOAD || !Subtarget.hasAVX() || !Ld->isSimple())
50838
+ return SDValue();
50839
+
50840
+ if (!(RegVT.is128BitVector() || RegVT.is256BitVector()))
50841
+ return SDValue();
50842
+
50843
+ auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
50844
+ ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
50845
+ for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
50846
+ if (Undefs[I])
50847
+ continue;
50848
+ if (UserUndefs[I] || Bits[I] != UserBits[I])
50849
+ return false;
50850
+ }
50851
+ return true;
50852
+ };
50853
+
50854
+ // Look through all other loads/broadcasts in the chain for another constant
50855
+ // pool entry.
50856
+ for (SDNode *User : Chain->uses()) {
50857
+ auto *UserLd = dyn_cast<MemSDNode>(User);
50858
+ if (User != N && UserLd &&
50859
+ (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
50860
+ User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
50861
+ ISD::isNormalLoad(User)) &&
50862
+ UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
50863
+ User->getValueSizeInBits(0).getFixedValue() >
50864
+ RegVT.getFixedSizeInBits()) {
50865
+ EVT UserVT = User->getValueType(0);
50866
+ SDValue UserPtr = UserLd->getBasePtr();
50867
+ const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
50868
+ const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
50869
+
50870
+ // See if we are loading a constant that matches in the lower
50871
+ // bits of a longer constant (but from a different constant pool ptr).
50872
+ if (LdC && UserC && UserPtr != Ptr) {
50873
+ unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
50874
+ unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
50875
+ if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
50876
+ APInt Undefs, UserUndefs;
50877
+ SmallVector<APInt> Bits, UserBits;
50878
+ unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
50879
+ UserVT.getScalarSizeInBits());
50880
+ if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
50881
+ Bits) &&
50882
+ getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
50883
+ UserUndefs, UserBits)) {
50884
+ if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
50885
+ SDValue Extract = extractSubVector(
50886
+ SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
50887
+ Extract = DAG.getBitcast(RegVT, Extract);
50888
+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
50889
+ }
50890
+ }
50891
+ }
50892
+ }
50893
+ }
50894
+ }
50895
+
50896
+ return SDValue();
50897
+ }
50898
+
50826
50899
static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
50827
50900
TargetLowering::DAGCombinerInfo &DCI,
50828
50901
const X86Subtarget &Subtarget) {
50829
- LoadSDNode *Ld = cast<LoadSDNode>(N);
50902
+ auto *Ld = cast<LoadSDNode>(N);
50830
50903
EVT RegVT = Ld->getValueType(0);
50831
50904
EVT MemVT = Ld->getMemoryVT();
50832
50905
SDLoc dl(Ld);
@@ -50885,7 +50958,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
50885
50958
}
50886
50959
}
50887
50960
50888
- // If we also load/ broadcast this to a wider type, then just extract the
50961
+ // If we also broadcast this vector to a wider type, then just extract the
50889
50962
// lowest subvector.
50890
50963
if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
50891
50964
(RegVT.is128BitVector() || RegVT.is256BitVector())) {
@@ -50894,61 +50967,23 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
50894
50967
for (SDNode *User : Chain->uses()) {
50895
50968
auto *UserLd = dyn_cast<MemSDNode>(User);
50896
50969
if (User != N && UserLd &&
50897
- ( User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
50898
- User->getOpcode () == X86ISD::VBROADCAST_LOAD ||
50899
- ISD::isNormalLoad(User) ) &&
50900
- UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
50970
+ User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50971
+ UserLd->getChain() == Chain && UserLd->getBasePtr () == Ptr &&
50972
+ UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits( ) &&
50973
+ !User->hasAnyUseOfValue(1) &&
50901
50974
User->getValueSizeInBits(0).getFixedValue() >
50902
50975
RegVT.getFixedSizeInBits()) {
50903
- if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
50904
- UserLd->getBasePtr() == Ptr &&
50905
- UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
50906
- SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
50907
- RegVT.getSizeInBits());
50908
- Extract = DAG.getBitcast(RegVT, Extract);
50909
- return DCI.CombineTo(N, Extract, SDValue(User, 1));
50910
- }
50911
- auto MatchingBits = [](const APInt &Undefs, const APInt &UserUndefs,
50912
- ArrayRef<APInt> Bits, ArrayRef<APInt> UserBits) {
50913
- for (unsigned I = 0, E = Undefs.getBitWidth(); I != E; ++I) {
50914
- if (Undefs[I])
50915
- continue;
50916
- if (UserUndefs[I] || Bits[I] != UserBits[I])
50917
- return false;
50918
- }
50919
- return true;
50920
- };
50921
- // See if we are loading a constant that matches in the lower
50922
- // bits of a longer constant (but from a different constant pool ptr).
50923
- EVT UserVT = User->getValueType(0);
50924
- SDValue UserPtr = UserLd->getBasePtr();
50925
- const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
50926
- const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
50927
- if (LdC && UserC && UserPtr != Ptr) {
50928
- unsigned LdSize = LdC->getType()->getPrimitiveSizeInBits();
50929
- unsigned UserSize = UserC->getType()->getPrimitiveSizeInBits();
50930
- if (LdSize < UserSize || !ISD::isNormalLoad(User)) {
50931
- APInt Undefs, UserUndefs;
50932
- SmallVector<APInt> Bits, UserBits;
50933
- unsigned NumBits = std::min(RegVT.getScalarSizeInBits(),
50934
- UserVT.getScalarSizeInBits());
50935
- if (getTargetConstantBitsFromNode(SDValue(N, 0), NumBits, Undefs,
50936
- Bits) &&
50937
- getTargetConstantBitsFromNode(SDValue(User, 0), NumBits,
50938
- UserUndefs, UserBits)) {
50939
- if (MatchingBits(Undefs, UserUndefs, Bits, UserBits)) {
50940
- SDValue Extract = extractSubVector(
50941
- SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
50942
- Extract = DAG.getBitcast(RegVT, Extract);
50943
- return DCI.CombineTo(N, Extract, SDValue(User, 1));
50944
- }
50945
- }
50946
- }
50947
- }
50976
+ SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, dl,
50977
+ RegVT.getSizeInBits());
50978
+ Extract = DAG.getBitcast(RegVT, Extract);
50979
+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
50948
50980
}
50949
50981
}
50950
50982
}
50951
50983
50984
+ if (SDValue V = combineConstantPoolLoads(Ld, dl, DAG, DCI, Subtarget))
50985
+ return V;
50986
+
50952
50987
// Cast ptr32 and ptr64 pointers to the default address space before a load.
50953
50988
unsigned AddrSpace = Ld->getAddressSpace();
50954
50989
if (AddrSpace == X86AS::PTR64 || AddrSpace == X86AS::PTR32_SPTR ||
0 commit comments