
Commit 9f237c9

[DAGCombine] Refactor DAGCombiner::ReduceLoadWidth. NFCI
Update code comments in DAGCombiner::ReduceLoadWidth and refactor the
handling of SRL a bit. The refactoring is done with the intent of adding
support for folding away SRA by using SEXTLOAD in a follow-up patch.

The function is also renamed as DAGCombiner::reduceLoadWidth.

Differential Revision: https://reviews.llvm.org/D117104
1 parent 37e6496 commit 9f237c9
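(Not part of the commit: for readers skimming the diff, a minimal C++ sketch of the headline fold this function performs, assuming a little-endian target where the narrow load is legal. The function name is hypothetical.)

#include <cstdint>

// DAG before: (truncate (srl (load i32 p), 16)) to i16
// DAG after:  (load i16 (p + 2)), a narrow load at a byte offset
uint16_t highHalf(const uint32_t *p) {
  return static_cast<uint16_t>(*p >> 16); // typically compiles to a 2-byte load
}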

File tree

1 file changed (+104 −81 lines)


llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 104 additions & 81 deletions
@@ -593,7 +593,7 @@ namespace {
     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue mergeTruncStores(StoreSDNode *N);
-    SDValue ReduceLoadWidth(SDNode *N);
+    SDValue reduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
     SDValue TransformFPLoadStorePair(SDNode *N);
@@ -5624,7 +5624,7 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
     if (And.getOpcode() == ISD::AND)
       And = SDValue(
           DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
-    SDValue NewLoad = ReduceLoadWidth(And.getNode());
+    SDValue NewLoad = reduceLoadWidth(And.getNode());
     assert(NewLoad &&
            "Shouldn't be masking the load if it can't be narrowed");
     CombineTo(Load, NewLoad, NewLoad.getValue(1));
@@ -6024,7 +6024,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
-    if (SDValue Res = ReduceLoadWidth(N)) {
+    if (SDValue Res = reduceLoadWidth(N)) {
       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
         ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
       AddToWorklist(N);
@@ -9140,7 +9140,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     return NewSRL;
 
   // Attempt to convert a srl of a load into a narrower zero-extending load.
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // Here is a common situation. We want to optimize:
@@ -11357,7 +11357,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (sext (truncate (load x))) -> (sext (smaller load x))
     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11621,7 +11621,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (zext (truncate (load x))) -> (zext (smaller load x))
     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11864,7 +11864,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
   // fold (aext (truncate (load x))) -> (aext (smaller load x))
   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
      SDNode *oye = N0.getOperand(0).getNode();
      if (NarrowLoad.getNode() != N0.getNode()) {
        CombineTo(N0.getNode(), NarrowLoad);
@@ -12095,13 +12095,10 @@ SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
   return SDValue();
 }
 
-/// If the result of a wider load is shifted to right of N bits and then
-/// truncated to a narrower type and where N is a multiple of number of bits of
-/// the narrower type, transform it to a narrower load from address + N / num of
-/// bits of new type. Also narrow the load if the result is masked with an AND
-/// to effectively produce a smaller type. If the result is to be extended, also
-/// fold the extension to form a extending load.
-SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
+/// If the result of a load is shifted/masked/truncated to an effectively
+/// narrower type, try to transform the load to a narrower type and/or
+/// use an extending load.
+SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
   unsigned Opc = N->getOpcode();
 
   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
@@ -12113,7 +12110,14 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
   if (VT.isVector())
     return SDValue();
 
+  // The ShAmt variable is used to indicate that we've consumed a right
+  // shift. I.e. we want to narrow the width of the load by skipping to load
+  // the ShAmt least significant bits.
   unsigned ShAmt = 0;
+  // A special case is when the least significant bits from the load are
+  // masked away, but using an AND rather than a right shift. HasShiftedOffset
+  // is used to indicate that the narrowed load should be left-shifted ShAmt
+  // bits to get the result.
   bool HasShiftedOffset = false;
   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
   // extended to VT.
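(Aside, not from the patch: a C++ illustration of the two narrowing modes the new comments describe, assuming a little-endian target that accepts the narrow loads. Function names are hypothetical.)

#include <cstdint>

// Right-shift case: ShAmt = 16 consumes the 16 low bits, so the narrowed
// load can read from byte offset 2 instead.
uint8_t thirdByte(const uint32_t *p) { return uint8_t(*p >> 16); }

// AND case (HasShiftedOffset): the mask 0xFF00 discards the low 8 bits, so
// the narrowed i8 load reads byte offset 1 and the loaded value is shifted
// left by ShAmt again to land in the right position.
uint32_t maskedByte(const uint32_t *p) { return *p & 0xFF00u; }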
@@ -12122,23 +12126,29 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
   } else if (Opc == ISD::SRL) {
     // Another special-case: SRL is basically zero-extending a narrower value,
-    // or it maybe shifting a higher subword, half or byte into the lowest
+    // or it may be shifting a higher subword, half or byte into the lowest
     // bits.
-    ExtType = ISD::ZEXTLOAD;
-    N0 = SDValue(N, 0);
 
-    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
-    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!N01 || !LN0)
+    // Only handle shift with constant shift amount, and the shiftee must be a
+    // load.
+    auto *LN = dyn_cast<LoadSDNode>(N0);
+    auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!N1C || !LN)
+      return SDValue();
+    // If the shift amount is larger than the memory type then we're not
+    // accessing any of the loaded bytes.
+    ShAmt = N1C->getZExtValue();
+    uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
+    if (MemoryWidth <= ShAmt)
+      return SDValue();
+    // Attempt to fold away the SRL by using ZEXTLOAD.
+    ExtType = ISD::ZEXTLOAD;
+    ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
+    // If original load is a SEXTLOAD then we can't simply replace it by a
+    // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
+    // followed by a ZEXT, but that is not handled at the moment).
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
       return SDValue();
-
-    uint64_t ShiftAmt = N01->getZExtValue();
-    uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
-    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
-    else
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
-                                VT.getScalarSizeInBits() - ShiftAmt);
   } else if (Opc == ISD::AND) {
     // An AND with a constant mask is the same as a truncate + zero-extend.
     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
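(Illustrative aside, not from the patch: why the SEXTLOAD bail-out above is needed. For *p == -1 the sign-extended value is 0xFFFFFFFF, and the logical shift keeps copies of the sign bit that no narrowed load of the memory bytes could reproduce.)

#include <cstdint>

uint32_t shiftedSext(const int8_t *p) {
  int32_t v = *p;          // sextload i8 -> i32
  return uint32_t(v) >> 8; // 0x00FFFFFF when *p is negative
}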
@@ -12161,55 +12171,73 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
   }
 
-  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
-    SDValue SRL = N0;
-    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
-      ShAmt = ConstShift->getZExtValue();
-      unsigned EVTBits = ExtVT.getScalarSizeInBits();
-      // Is the shift amount a multiple of size of VT?
-      if ((ShAmt & (EVTBits-1)) == 0) {
-        N0 = N0.getOperand(0);
-        // Is the load width a multiple of size of VT?
-        if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
-          return SDValue();
-      }
+  // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
+  // a right shift. Here we redo some of those checks, to possibly adjust the
+  // ExtVT even further based on "a masking AND". We could also end up here for
+  // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
+  // need to be done here as well.
+  if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
+    SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
+    // Bail out when the SRL has more than one use. This is done for historical
+    // (undocumented) reasons. Maybe the intent was to guard the AND-masking
+    // check below? And maybe it could be non-profitable to do the transform in
+    // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
+    // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
+    if (!SRL.hasOneUse())
+      return SDValue();
+
+    // Only handle shift with constant shift amount, and the shiftee must be a
+    // load.
+    auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
+    auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
+    if (!SRL1C || !LN)
+      return SDValue();
 
-    // At this point, we must have a load or else we can't do the transform.
-    auto *LN0 = dyn_cast<LoadSDNode>(N0);
-    if (!LN0) return SDValue();
+    // If the shift amount is larger than the input type then we're not
+    // accessing any of the loaded bytes. If the load was a zextload/extload
+    // then the result of the shift+trunc is zero/undef (handled elsewhere).
+    ShAmt = SRL1C->getZExtValue();
+    if (ShAmt >= LN->getMemoryVT().getSizeInBits())
+      return SDValue();
 
-    // Because a SRL must be assumed to *need* to zero-extend the high bits
-    // (as opposed to anyext the high bits), we can't combine the zextload
-    // lowering of SRL and an sextload.
-    if (LN0->getExtensionType() == ISD::SEXTLOAD)
-      return SDValue();
+    // Because a SRL must be assumed to *need* to zero-extend the high bits
+    // (as opposed to anyext the high bits), we can't combine the zextload
+    // lowering of SRL and an sextload.
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
+      return SDValue();
 
-    // If the shift amount is larger than the input type then we're not
-    // accessing any of the loaded bytes. If the load was a zextload/extload
-    // then the result of the shift+trunc is zero/undef (handled elsewhere).
-    if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
-      return SDValue();
+    unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
+    // Is the shift amount a multiple of size of ExtVT?
+    if ((ShAmt & (ExtVTBits - 1)) != 0)
+      return SDValue();
+    // Is the load width a multiple of size of ExtVT?
+    if ((SRL.getScalarValueSizeInBits() & (ExtVTBits - 1)) != 0)
+      return SDValue();
 
-    // If the SRL is only used by a masking AND, we may be able to adjust
-    // the ExtVT to make the AND redundant.
-    SDNode *Mask = *(SRL->use_begin());
-    if (Mask->getOpcode() == ISD::AND &&
-        isa<ConstantSDNode>(Mask->getOperand(1))) {
-      const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
-      if (ShiftMask.isMask()) {
-        EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
-                                         ShiftMask.countTrailingOnes());
-        // If the mask is smaller, recompute the type.
-        if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
-            TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
-          ExtVT = MaskedVT;
-      }
+    // If the SRL is only used by a masking AND, we may be able to adjust
+    // the ExtVT to make the AND redundant.
+    SDNode *Mask = *(SRL->use_begin());
+    if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
+        isa<ConstantSDNode>(Mask->getOperand(1))) {
+      const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
+      if (ShiftMask.isMask()) {
+        EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
+                                         ShiftMask.countTrailingOnes());
+        // If the mask is smaller, recompute the type.
+        if ((ExtVTBits > MaskedVT.getScalarSizeInBits()) &&
+            TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
+          ExtVT = MaskedVT;
+      }
     }
+
+    N0 = SRL.getOperand(0);
   }
 
-  // If the load is shifted left (and the result isn't shifted back right),
-  // we can fold the truncate through the shift.
+  // If the load is shifted left (and the result isn't shifted back right), we
+  // can fold a truncate through the shift. The typical scenario is that N
+  // points at a TRUNCATE here so the attempted fold is:
+  //   (truncate (shl (load x), c)) -> (shl (narrow load x), c)
+  // ShLeftAmt will indicate how much a narrowed load should be shifted left.
   unsigned ShLeftAmt = 0;
   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
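(Illustrative aside, not from the patch: a concrete shape for the masking-AND adjustment above, assuming an i8 zextload is legal for the target.)

#include <cstdint>

// (and (srl (load i32 p), 16), 255): the SRL alone suggests a 16-bit
// zextload, but the mask keeps only 8 bits, so ExtVT can shrink to i8 and
// the AND becomes redundant: one byte loaded from offset 2 (little endian).
uint32_t maskedShift(const uint32_t *p) { return (*p >> 16) & 0xFFu; }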
@@ -12237,12 +12265,12 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     return LVTStoreBits - EVTStoreBits - ShAmt;
   };
 
-  // For big endian targets, we need to adjust the offset to the pointer to
-  // load the correct bytes.
-  if (DAG.getDataLayout().isBigEndian())
-    ShAmt = AdjustBigEndianShift(ShAmt);
+  // We need to adjust the pointer to the load by ShAmt bits in order to load
+  // the correct bytes.
+  unsigned PtrAdjustmentInBits =
+      DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
 
-  uint64_t PtrOff = ShAmt / 8;
+  uint64_t PtrOff = PtrAdjustmentInBits / 8;
   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
   SDLoc DL(LN0);
   // The original load itself didn't wrap, so an offset within it doesn't.
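(Illustrative aside, not from the patch: the pointer adjustment can be sanity-checked at compile time. ptrAdjustmentInBits is a hypothetical helper mirroring the ShAmt/AdjustBigEndianShift logic above.)

#include <cstdint>

constexpr uint64_t ptrAdjustmentInBits(bool BigEndian, uint64_t LoadBits,
                                       uint64_t NarrowBits, uint64_t ShAmt) {
  // Big endian: the wanted bits sit at the opposite end of the object.
  return BigEndian ? LoadBits - NarrowBits - ShAmt : ShAmt;
}

// i32 load, srl by 16, narrowed to i16:
static_assert(ptrAdjustmentInBits(false, 32, 16, 16) / 8 == 2, "LE: offset 2");
static_assert(ptrAdjustmentInBits(true, 32, 16, 16) / 8 == 0, "BE: offset 0");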
@@ -12285,11 +12313,6 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
   }
 
   if (HasShiftedOffset) {
-    // Recalculate the shift amount after it has been altered to calculate
-    // the offset.
-    if (DAG.getDataLayout().isBigEndian())
-      ShAmt = AdjustBigEndianShift(ShAmt);
-
     // We're using a shifted mask, so the load now has an offset. This means
     // that data has been loaded into the lower bytes than it would have been
     // before, so we need to shl the loaded data into the correct position in the
@@ -12382,7 +12405,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
 
   // fold (sext_in_reg (load x)) -> (smaller sextload x)
   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
@@ -12669,7 +12692,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // fold (truncate (load x)) -> (smaller load x)
   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
-    if (SDValue Reduced = ReduceLoadWidth(N))
+    if (SDValue Reduced = reduceLoadWidth(N))
       return Reduced;
 
     // Handle the case where the load remains an extending load even
