@@ -593,7 +593,7 @@ namespace {
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
SDValue MatchLoadCombine(SDNode *N);
SDValue mergeTruncStores(StoreSDNode *N);
- SDValue ReduceLoadWidth(SDNode *N);
+ SDValue reduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
SDValue splitMergedValStore(StoreSDNode *ST);
SDValue TransformFPLoadStorePair(SDNode *N);
@@ -5624,7 +5624,7 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
if (And.getOpcode() == ISD::AND)
And = SDValue(
DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
- SDValue NewLoad = ReduceLoadWidth(And.getNode());
+ SDValue NewLoad = reduceLoadWidth(And.getNode());
assert(NewLoad &&
"Shouldn't be masking the load if it can't be narrowed");
CombineTo(Load, NewLoad, NewLoad.getValue(1));
@@ -6024,7 +6024,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
(N0.getOpcode() == ISD::ANY_EXTEND &&
N0.getOperand(0).getOpcode() == ISD::LOAD))) {
- if (SDValue Res = ReduceLoadWidth(N)) {
+ if (SDValue Res = reduceLoadWidth(N)) {
LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
AddToWorklist(N);
@@ -9140,7 +9140,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
return NewSRL;

// Attempt to convert a srl of a load into a narrower zero-extending load.
- if (SDValue NarrowLoad = ReduceLoadWidth(N))
+ if (SDValue NarrowLoad = reduceLoadWidth(N))
return NarrowLoad;

// Here is a common situation. We want to optimize:
@@ -11357,7 +11357,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (sext (truncate (load x))) -> (sext (smaller load x))
// fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
- if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+ if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -11621,7 +11621,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
if (N0.getOpcode() == ISD::TRUNCATE) {
// fold (zext (truncate (load x))) -> (zext (smaller load x))
// fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
- if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+ if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -11864,7 +11864,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
// fold (aext (truncate (load x))) -> (aext (smaller load x))
// fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
if (N0.getOpcode() == ISD::TRUNCATE) {
- if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+ if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
SDNode *oye = N0.getOperand(0).getNode();
if (NarrowLoad.getNode() != N0.getNode()) {
CombineTo(N0.getNode(), NarrowLoad);
@@ -12095,13 +12095,10 @@ SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
return SDValue();
}

- /// If the result of a wider load is shifted to right of N bits and then
- /// truncated to a narrower type and where N is a multiple of number of bits of
- /// the narrower type, transform it to a narrower load from address + N / num of
- /// bits of new type. Also narrow the load if the result is masked with an AND
- /// to effectively produce a smaller type. If the result is to be extended, also
- /// fold the extension to form a extending load.
- SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
+ /// If the result of a load is shifted/masked/truncated to an effectively
+ /// narrower type, try to transform the load to a narrower type and/or
+ /// use an extending load.
+ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
unsigned Opc = N->getOpcode();

ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
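For orientation, a few source-level patterns that typically reach reduceLoadWidth as a truncate, right shift, or masking AND of a load (an illustrative sketch with hypothetical function names; whether the narrowing actually happens depends on the target and on which extending loads are legal):

#include <cstdint>

// (and (load i32 p), 255): may become a one-byte zero-extending load.
uint8_t low_byte(const uint32_t *p) { return *p & 0xff; }

// (truncate (srl (load i32 p), 16)): may become a two-byte load from an
// address offset by 2 bytes on a little-endian target.
uint16_t high_half(const uint32_t *p) { return (uint16_t)(*p >> 16); }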
@@ -12113,7 +12110,14 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
if (VT.isVector())
return SDValue();

+ // The ShAmt variable is used to indicate that we've consumed a right
+ // shift. I.e. we want to narrow the width of the load by skipping to load the
+ // ShAmt least significant bits.
unsigned ShAmt = 0;
+ // A special case is when the least significant bits from the load are masked
+ // away, but using an AND rather than a right shift. HasShiftedOffset is used
+ // to indicate that the narrowed load should be left-shifted ShAmt bits to get
+ // the result.
bool HasShiftedOffset = false;
// Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
// extended to VT.
@@ -12122,23 +12126,29 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
} else if (Opc == ISD::SRL) {
// Another special-case: SRL is basically zero-extending a narrower value,
- // or it maybe shifting a higher subword, half or byte into the lowest
+ // or it may be shifting a higher subword, half or byte into the lowest
// bits.
- ExtType = ISD::ZEXTLOAD;
- N0 = SDValue(N, 0);

- auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
- auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
- if (!N01 || !LN0)
+ // Only handle shift with constant shift amount, and the shiftee must be a
+ // load.
+ auto *LN = dyn_cast<LoadSDNode>(N0);
+ auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ if (!N1C || !LN)
+ return SDValue();
+ // If the shift amount is larger than the memory type then we're not
+ // accessing any of the loaded bytes.
+ ShAmt = N1C->getZExtValue();
+ uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
+ if (MemoryWidth <= ShAmt)
+ return SDValue();
+ // Attempt to fold away the SRL by using ZEXTLOAD.
+ ExtType = ISD::ZEXTLOAD;
+ ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
+ // If original load is a SEXTLOAD then we can't simply replace it by a
+ // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
+ // followed by a ZEXT, but that is not handled at the moment).
+ if (LN->getExtensionType() == ISD::SEXTLOAD)
return SDValue();
-
- uint64_t ShiftAmt = N01->getZExtValue();
- uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
- if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
- ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
- else
- ExtVT = EVT::getIntegerVT(*DAG.getContext(),
- VT.getScalarSizeInBits() - ShiftAmt);
} else if (Opc == ISD::AND) {
// An AND with a constant mask is the same as a truncate + zero-extend.
auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
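As a standalone paraphrase of the arithmetic in the Opc==ISD::SRL branch above (hypothetical helper, not part of the patch): the narrowed zero-extending load keeps MemoryWidth - ShAmt bits, and a shift amount that reaches the memory width means none of the loaded bits survive.

#include <cassert>

// Bit width of the narrowed zextload for (srl (load), ShAmt), mirroring
// ExtVT = getIntegerVT(MemoryWidth - ShAmt) above.
static unsigned narrowedZExtLoadBits(unsigned MemoryWidth, unsigned ShAmt) {
  assert(ShAmt < MemoryWidth && "shift consumes the whole loaded value");
  return MemoryWidth - ShAmt;
}

// For example, (srl (load i32), 16) keeps 32 - 16 == 16 bits.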
@@ -12161,55 +12171,73 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
}

- if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
- SDValue SRL = N0;
- if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
- ShAmt = ConstShift->getZExtValue();
- unsigned EVTBits = ExtVT.getScalarSizeInBits();
- // Is the shift amount a multiple of size of VT?
- if ((ShAmt & (EVTBits-1)) == 0) {
- N0 = N0.getOperand(0);
- // Is the load width a multiple of size of VT?
- if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
- return SDValue();
- }
+ // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
+ // a right shift. Here we redo some of those checks, to possibly adjust the
+ // ExtVT even further based on "a masking AND". We could also end up here for
+ // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
+ // need to be done here as well.
+ if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
+ SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
+ // Bail out when the SRL has more than one use. This is done for historical
+ // (undocumented) reasons. Maybe intent was to guard the AND-masking below
+ // check below? And maybe it could be non-profitable to do the transform in
+ // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
+ // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
+ if (!SRL.hasOneUse())
+ return SDValue();
+
+ // Only handle shift with constant shift amount, and the shiftee must be a
+ // load.
+ auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
+ auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
+ if (!SRL1C || !LN)
+ return SDValue();

- // At this point, we must have a load or else we can't do the transform.
- auto *LN0 = dyn_cast<LoadSDNode>(N0);
- if (!LN0) return SDValue();
+ // If the shift amount is larger than the input type then we're not
+ // accessing any of the loaded bytes. If the load was a zextload/extload
+ // then the result of the shift+trunc is zero/undef (handled elsewhere).
+ ShAmt = SRL1C->getZExtValue();
+ if (ShAmt >= LN->getMemoryVT().getSizeInBits())
+ return SDValue();

- // Because a SRL must be assumed to *need* to zero-extend the high bits
- // (as opposed to anyext the high bits), we can't combine the zextload
- // lowering of SRL and an sextload.
- if (LN0->getExtensionType() == ISD::SEXTLOAD)
- return SDValue();
+ // Because a SRL must be assumed to *need* to zero-extend the high bits
+ // (as opposed to anyext the high bits), we can't combine the zextload
+ // lowering of SRL and an sextload.
+ if (LN->getExtensionType() == ISD::SEXTLOAD)
+ return SDValue();

- // If the shift amount is larger than the input type then we're not
- // accessing any of the loaded bytes. If the load was a zextload/extload
- // then the result of the shift+trunc is zero/undef (handled elsewhere).
- if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
- return SDValue();
+ unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
+ // Is the shift amount a multiple of size of ExtVT?
+ if ((ShAmt & (ExtVTBits - 1)) != 0)
+ return SDValue();
+ // Is the load width a multiple of size of ExtVT?
+ if ((SRL.getScalarValueSizeInBits() & (ExtVTBits - 1)) != 0)
+ return SDValue();

- // If the SRL is only used by a masking AND, we may be able to adjust
- // the ExtVT to make the AND redundant.
- SDNode *Mask = *(SRL->use_begin());
- if (Mask->getOpcode() == ISD::AND &&
- isa<ConstantSDNode>(Mask->getOperand(1))) {
- const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
- if (ShiftMask.isMask()) {
- EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
- ShiftMask.countTrailingOnes());
- // If the mask is smaller, recompute the type.
- if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
- TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
- ExtVT = MaskedVT;
- }
+ // If the SRL is only used by a masking AND, we may be able to adjust
+ // the ExtVT to make the AND redundant.
+ SDNode *Mask = *(SRL->use_begin());
+ if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
+ isa<ConstantSDNode>(Mask->getOperand(1))) {
+ const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
+ if (ShiftMask.isMask()) {
+ EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
+ ShiftMask.countTrailingOnes());
+ // If the mask is smaller, recompute the type.
+ if ((ExtVTBits > MaskedVT.getScalarSizeInBits()) &&
+ TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
+ ExtVT = MaskedVT;
}
}
+
+ N0 = SRL.getOperand(0);
}

- // If the load is shifted left (and the result isn't shifted back right),
- // we can fold the truncate through the shift.
+ // If the load is shifted left (and the result isn't shifted back right), we
+ // can fold a truncate through the shift. The typical scenario is that N
+ // points at a TRUNCATE here so the attempted fold is:
+ // (truncate (shl (load x), c))) -> (shl (narrow load x), c)
+ // ShLeftAmt will indicate how much a narrowed load should be shifted left.
unsigned ShLeftAmt = 0;
if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
@@ -12237,12 +12265,12 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
return LVTStoreBits - EVTStoreBits - ShAmt;
};

- // For big endian targets, we need to adjust the offset to the pointer to
- // load the correct bytes.
- if (DAG.getDataLayout().isBigEndian())
- ShAmt = AdjustBigEndianShift(ShAmt);
+ // We need to adjust the pointer to the load by ShAmt bits in order to load
+ // the correct bytes.
+ unsigned PtrAdjustmentInBits =
+ DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;

- uint64_t PtrOff = ShAmt / 8;
+ uint64_t PtrOff = PtrAdjustmentInBits / 8;
Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
SDLoc DL(LN0);
// The original load itself didn't wrap, so an offset within it doesn't.
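A standalone paraphrase of the pointer adjustment above (hypothetical helper; widths are in bits, the result is a byte offset): ShAmt counts skipped least-significant bits, which maps directly to a byte offset on little-endian targets, while AdjustBigEndianShift counts from the opposite end of the original value on big-endian targets.

#include <cstdint>

// Byte offset added to the load pointer when the LoadWidth-bit value is
// narrowed to NarrowWidth bits after dropping the ShAmt low bits.
static uint64_t narrowLoadByteOffset(uint64_t LoadWidth, uint64_t NarrowWidth,
                                     uint64_t ShAmt, bool IsBigEndian) {
  uint64_t Bits = IsBigEndian ? LoadWidth - NarrowWidth - ShAmt : ShAmt;
  return Bits / 8;
}

// e.g. narrowing (srl (load i32), 16) to an i16 load: offset 2 on
// little-endian, offset 0 on big-endian (32 - 16 - 16 == 0).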
@@ -12285,11 +12313,6 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
}

if (HasShiftedOffset) {
- // Recalculate the shift amount after it has been altered to calculate
- // the offset.
- if (DAG.getDataLayout().isBigEndian())
- ShAmt = AdjustBigEndianShift(ShAmt);
-
// We're using a shifted mask, so the load now has an offset. This means
// that data has been loaded into the lower bytes than it would have been
// before, so we need to shl the loaded data into the correct position in the
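The HasShiftedOffset case corresponds to a mask that does not start at bit 0, for example (hypothetical function, little-endian assumption; whether the combine fires depends on the target):

#include <cstdint>

// (and (load i32 p), 0xff00): the byte at offset 1 may be loaded narrowly
// and then shifted left by 8 so it lands back in its original position.
uint32_t second_byte_in_place(const uint32_t *p) { return *p & 0xff00; }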
@@ -12382,7 +12405,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {

// fold (sext_in_reg (load x)) -> (smaller sextload x)
// fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
- if (SDValue NarrowLoad = ReduceLoadWidth(N))
+ if (SDValue NarrowLoad = reduceLoadWidth(N))
return NarrowLoad;

// fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
@@ -12669,7 +12692,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// fold (truncate (load x)) -> (smaller load x)
// fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
- if (SDValue Reduced = ReduceLoadWidth(N))
+ if (SDValue Reduced = reduceLoadWidth(N))
return Reduced;

// Handle the case where the load remains an extending load even