Skip to content

Commit 225fc4f

Browse files
authored
[AMDGPU][SDAG] Try folding "lshr i64 + mad" to "mad_u64_u32" (#119218)
The intention is to use a "copy" instead of a "sub" to handle the high part of the 64-bit multiply in this specific case. This unlocks copy-propagation opportunities where the copy can be reused by later multiply+add sequences. Fixes: SWDEV-487672, SWDEV-487669
1 parent f999b11 commit 225fc4f

File tree

2 files changed

+190
-115
lines changed

2 files changed

+190
-115
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13884,6 +13884,37 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
1388413884
return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
1388513885
}
1388613886

13887+
// Fold
13888+
// y = lshr i64 x, 32
13889+
// res = add (mul i64 y, Const), x where "Const" is a 64-bit constant
13890+
// with Const.hi == -1
13891+
// To
13892+
// res = mad_u64_u32 y.lo ,Const.lo, x.lo
13893+
static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
13894+
SDValue MulLHS, SDValue MulRHS,
13895+
SDValue AddRHS) {
13896+
if (MulRHS.getOpcode() == ISD::SRL)
13897+
std::swap(MulLHS, MulRHS);
13898+
13899+
if (MulLHS.getValueType() != MVT::i64 || MulLHS.getOpcode() != ISD::SRL)
13900+
return SDValue();
13901+
13902+
ConstantSDNode *ShiftVal = dyn_cast<ConstantSDNode>(MulLHS.getOperand(1));
13903+
if (!ShiftVal || ShiftVal->getAsZExtVal() != 32 ||
13904+
MulLHS.getOperand(0) != AddRHS)
13905+
return SDValue();
13906+
13907+
ConstantSDNode *Const = dyn_cast<ConstantSDNode>(MulRHS.getNode());
13908+
if (!Const || Hi_32(Const->getZExtValue()) != -1)
13909+
return SDValue();
13910+
13911+
SDValue ConstMul =
13912+
DAG.getConstant(Lo_32(Const->getZExtValue()), SL, MVT::i32);
13913+
return getMad64_32(DAG, SL, MVT::i64,
13914+
DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
13915+
DAG.getZeroExtendInReg(AddRHS, SL, MVT::i32), false);
13916+
}
13917+
1388713918
// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
1388813919
// multiplies, if any.
1388913920
//
@@ -13942,6 +13973,9 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
1394213973
SDValue MulRHS = LHS.getOperand(1);
1394313974
SDValue AddRHS = RHS;
1394413975

13976+
if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
13977+
return FoldedMAD;
13978+
1394513979
// Always check whether operands are small unsigned values, since that
1394613980
// knowledge is useful in more cases. Check for small signed values only if
1394713981
// doing so can unlock a shorter code sequence.

0 commit comments

Comments
 (0)