Skip to content

[SelectionDAG] Expand fixed point multiplication into libcall #79352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5287,6 +5287,24 @@ class TargetLowering : public TargetLoweringBase {
bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
SelectionDAG &DAG) const;

/// forceExpandWideMUL - Unconditionally expand a MUL into either a libcall or
/// a brute-force sequence of narrower operations.
///
/// The multiplication is performed on WideVT, a type twice the size of each
/// of the four operand halves. LL and LH are the lower and upper halves of the
/// first operand; RL and RH are the lower and upper halves of the second
/// operand. The lower half of the result is stored in Lo and the upper half
/// in Hi.
///
/// A multiply libcall is used when WideVT is i16, i32, i64 or i128 and the
/// target provides a name for that libcall; otherwise the product is built
/// from AND/SRL/SHL/MUL/ADD nodes on the half-width type. Signed is only
/// forwarded to the libcall's sign-extension option.
void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
EVT WideVT, const SDValue LL, const SDValue LH,
const SDValue RL, const SDValue RH, SDValue &Lo,
SDValue &Hi) const;

/// Same as above, but takes the two operands whole and creates the upper
/// half of each one itself: by sign-extension (an arithmetic shift right by
/// all but one bit) when Signed is true, or as a zero constant otherwise.
/// The wide type is an integer twice the width of LHS/RHS, which must share
/// the same value type. The lower and upper halves of the product are
/// returned in Lo and Hi.
void forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl, bool Signed,
const SDValue LHS, const SDValue RHS, SDValue &Lo,
SDValue &Hi) const;

/// Expand a VECREDUCE_* into an explicit calculation. If Count is specified,
/// only the first Count elements of the vector are used.
/// NOTE(review): this declaration has no Count parameter, so the sentence
/// above looks stale -- confirm against the definition.
SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
Expand Down
54 changes: 14 additions & 40 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4008,47 +4008,15 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(SDNode *N,
LC = RTLIB::MUL_I128;

if (LC == RTLIB::UNKNOWN_LIBCALL || !TLI.getLibcallName(LC)) {
// We'll expand the multiplication by brute force because we have no other
// options. This is a trivially-generalized version of the code from
// Hacker's Delight (itself derived from Knuth's Algorithm M from section
// 4.3.1).
unsigned Bits = NVT.getSizeInBits();
unsigned HalfBits = Bits >> 1;
SDValue Mask = DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl,
NVT);
SDValue LLL = DAG.getNode(ISD::AND, dl, NVT, LL, Mask);
SDValue RLL = DAG.getNode(ISD::AND, dl, NVT, RL, Mask);

SDValue T = DAG.getNode(ISD::MUL, dl, NVT, LLL, RLL);
SDValue TL = DAG.getNode(ISD::AND, dl, NVT, T, Mask);

SDValue Shift = DAG.getShiftAmountConstant(HalfBits, NVT, dl);
SDValue TH = DAG.getNode(ISD::SRL, dl, NVT, T, Shift);
SDValue LLH = DAG.getNode(ISD::SRL, dl, NVT, LL, Shift);
SDValue RLH = DAG.getNode(ISD::SRL, dl, NVT, RL, Shift);

SDValue U = DAG.getNode(ISD::ADD, dl, NVT,
DAG.getNode(ISD::MUL, dl, NVT, LLH, RLL), TH);
SDValue UL = DAG.getNode(ISD::AND, dl, NVT, U, Mask);
SDValue UH = DAG.getNode(ISD::SRL, dl, NVT, U, Shift);

SDValue V = DAG.getNode(ISD::ADD, dl, NVT,
DAG.getNode(ISD::MUL, dl, NVT, LLL, RLH), UL);
SDValue VH = DAG.getNode(ISD::SRL, dl, NVT, V, Shift);

SDValue W = DAG.getNode(ISD::ADD, dl, NVT,
DAG.getNode(ISD::MUL, dl, NVT, LLH, RLH),
DAG.getNode(ISD::ADD, dl, NVT, UH, VH));
Lo = DAG.getNode(ISD::ADD, dl, NVT, TL,
DAG.getNode(ISD::SHL, dl, NVT, V, Shift));

Hi = DAG.getNode(ISD::ADD, dl, NVT, W,
DAG.getNode(ISD::ADD, dl, NVT,
DAG.getNode(ISD::MUL, dl, NVT, RH, LL),
DAG.getNode(ISD::MUL, dl, NVT, RL, LH)));
// Perform a wide multiplication where the wide type is the original VT and
// the 4 parts are the split arguments.
TLI.forceExpandWideMUL(DAG, dl, /*Signed=*/true, VT, LL, LH, RL, RH, Lo,
Hi);
return;
}

// Note that we don't need to do a wide MUL here since we don't care about the
// upper half of the result if it exceeds VT.
SDValue Ops[2] = { N->getOperand(0), N->getOperand(1) };
TargetLowering::MakeLibCallOptions CallOptions;
CallOptions.setSExt(true);
Expand Down Expand Up @@ -4146,9 +4114,15 @@ void DAGTypeLegalizer::ExpandIntRes_MULFIX(SDNode *N, SDValue &Lo,
if (!TLI.expandMUL_LOHI(LoHiOp, VT, dl, LHS, RHS, Result, NVT, DAG,
TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
LL, LH, RL, RH)) {
report_fatal_error("Unable to expand MUL_FIX using MUL_LOHI.");
return;
Result.clear();
Result.resize(4);

SDValue LoTmp, HiTmp;
TLI.forceExpandWideMUL(DAG, dl, Signed, LHS, RHS, LoTmp, HiTmp);
SplitInteger(LoTmp, Result[0], Result[1]);
SplitInteger(HiTmp, Result[2], Result[3]);
}
assert(Result.size() == 4 && "Unexpected number of partlets in the result");

unsigned NVTSize = NVT.getScalarSizeInBits();
assert((VTSize == NVTSize * 2) && "Expected the new value type to be half "
Expand Down
182 changes: 118 additions & 64 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10149,6 +10149,122 @@ SDValue TargetLowering::expandShlSat(SDNode *Node, SelectionDAG &DAG) const {
return DAG.getSelect(dl, VT, Cond, SatVal, Result);
}

/// Multiply two WideVT-sized values that are supplied as (low, high) halves.
///
/// \param Signed  forwarded to the libcall's sign-extension option; the
///                inline expansion below does not consult it.
/// \param WideVT  type of the full-width product; selects the libcall.
/// \param LL, LH  lower / upper half of the first operand.
/// \param RL, RH  lower / upper half of the second operand.
/// \param Lo, Hi  receive the lower / upper half of the product.
void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
                                        bool Signed, EVT WideVT,
                                        const SDValue LL, const SDValue LH,
                                        const SDValue RL, const SDValue RH,
                                        SDValue &Lo, SDValue &Hi) const {
  // Map the width of the full product onto the matching runtime multiply
  // routine, if one exists for that width.
  const auto PickMulLibcall = [](EVT Ty) -> RTLIB::Libcall {
    if (Ty == MVT::i16)
      return RTLIB::MUL_I16;
    if (Ty == MVT::i32)
      return RTLIB::MUL_I32;
    if (Ty == MVT::i64)
      return RTLIB::MUL_I64;
    if (Ty == MVT::i128)
      return RTLIB::MUL_I128;
    return RTLIB::UNKNOWN_LIBCALL;
  };
  const RTLIB::Libcall LC = PickMulLibcall(WideVT);

  // Prefer a libcall on the (illegal) wide type when the target names one.
  if (LC != RTLIB::UNKNOWN_LIBCALL && getLibcallName(LC)) {
    TargetLowering::MakeLibCallOptions CallOptions;
    CallOptions.setSExt(Signed);
    CallOptions.setIsPostTypeLegalization(true);

    // Halves of WideVT are packed into registers in different order depending
    // on platform endianness. This is usually handled by the C calling
    // convention, but we can't defer to it in the legalizer.
    const bool LowHalfFirst =
        shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout());
    SDValue Args[] = {LowHalfFirst ? LL : LH, LowHalfFirst ? LH : LL,
                      LowHalfFirst ? RL : RH, LowHalfFirst ? RH : RL};
    SDValue Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
    assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
           "Ret value is a collection of constituent nodes holding result.");

    // The two result halves come back in an endian-dependent order as well.
    const unsigned LoIdx = DAG.getDataLayout().isLittleEndian() ? 0 : 1;
    Lo = Ret.getOperand(LoIdx);
    Hi = Ret.getOperand(1 - LoIdx);
    return;
  }

  // No usable libcall: expand the multiplication by brute force. This is a
  // trivially-generalized version of the code from Hacker's Delight (itself
  // derived from Knuth's Algorithm M from section 4.3.1). Each low operand
  // half is split again so the partial products can be summed in VT.
  const EVT VT = LL.getValueType();
  const unsigned Bits = VT.getSizeInBits();
  const unsigned HalfBits = Bits >> 1;

  SDValue HalfMask =
      DAG.getConstant(APInt::getLowBitsSet(Bits, HalfBits), dl, VT);
  SDValue LLLo = DAG.getNode(ISD::AND, dl, VT, LL, HalfMask);
  SDValue RLLo = DAG.getNode(ISD::AND, dl, VT, RL, HalfMask);

  // low(LL) * low(RL), then its own low and high pieces.
  SDValue LoLo = DAG.getNode(ISD::MUL, dl, VT, LLLo, RLLo);
  SDValue LoLoLow = DAG.getNode(ISD::AND, dl, VT, LoLo, HalfMask);

  SDValue HalfShift = DAG.getShiftAmountConstant(HalfBits, VT, dl);
  SDValue LoLoHigh = DAG.getNode(ISD::SRL, dl, VT, LoLo, HalfShift);
  SDValue LLHi = DAG.getNode(ISD::SRL, dl, VT, LL, HalfShift);
  SDValue RLHi = DAG.getNode(ISD::SRL, dl, VT, RL, HalfShift);

  // First cross term: high(LL) * low(RL) plus the carry-in from LoLo.
  SDValue CrossAProd = DAG.getNode(ISD::MUL, dl, VT, LLHi, RLLo);
  SDValue CrossA = DAG.getNode(ISD::ADD, dl, VT, CrossAProd, LoLoHigh);
  SDValue CrossALow = DAG.getNode(ISD::AND, dl, VT, CrossA, HalfMask);
  SDValue CrossAHigh = DAG.getNode(ISD::SRL, dl, VT, CrossA, HalfShift);

  // Second cross term: low(LL) * high(RL) plus the low piece of the first.
  SDValue CrossBProd = DAG.getNode(ISD::MUL, dl, VT, LLLo, RLHi);
  SDValue CrossB = DAG.getNode(ISD::ADD, dl, VT, CrossBProd, CrossALow);
  SDValue CrossBHigh = DAG.getNode(ISD::SRL, dl, VT, CrossB, HalfShift);

  // high(LL) * high(RL) plus everything carried out of the cross terms.
  SDValue HiHiProd = DAG.getNode(ISD::MUL, dl, VT, LLHi, RLHi);
  SDValue CrossCarry = DAG.getNode(ISD::ADD, dl, VT, CrossAHigh, CrossBHigh);
  SDValue UpperOfLLxRL = DAG.getNode(ISD::ADD, dl, VT, HiHiProd, CrossCarry);

  SDValue CrossBShifted = DAG.getNode(ISD::SHL, dl, VT, CrossB, HalfShift);
  Lo = DAG.getNode(ISD::ADD, dl, VT, LoLoLow, CrossBShifted);

  // The products involving LH and RH only contribute to the upper half.
  SDValue RHxLL = DAG.getNode(ISD::MUL, dl, VT, RH, LL);
  SDValue RLxLH = DAG.getNode(ISD::MUL, dl, VT, RL, LH);
  SDValue HighCross = DAG.getNode(ISD::ADD, dl, VT, RHxLL, RLxLH);
  Hi = DAG.getNode(ISD::ADD, dl, VT, UpperOfLLxRL, HighCross);
}

/// Multiply LHS by RHS at twice their width and return the product in halves.
///
/// The upper half of each operand is created here and the work is delegated
/// to the split-operand overload, using an integer type twice as wide as the
/// operands as the wide type.
///
/// \param Signed    when true the upper halves are the sign fill of each
///                  operand; when false they are zero.
/// \param LHS, RHS  operands; both must have the same value type.
/// \param Lo, Hi    receive the lower / upper half of the product.
void TargetLowering::forceExpandWideMUL(SelectionDAG &DAG, const SDLoc &dl,
                                        bool Signed, const SDValue LHS,
                                        const SDValue RHS, SDValue &Lo,
                                        SDValue &Hi) const {
  EVT VT = LHS.getValueType();
  assert(RHS.getValueType() == VT && "Mismatching operand types");

  SDValue HiLHS;
  SDValue HiRHS;
  if (Signed) {
    // The high part is obtained by SRA'ing all but one of the bits of low
    // part. Build the shift amount with the shift-amount helper (as the
    // split-operand overload does) rather than with the pointer type, and
    // share the one constant between both shifts.
    unsigned LoSize = VT.getFixedSizeInBits();
    SDValue SignShift = DAG.getShiftAmountConstant(LoSize - 1, VT, dl);
    HiLHS = DAG.getNode(ISD::SRA, dl, VT, LHS, SignShift);
    HiRHS = DAG.getNode(ISD::SRA, dl, VT, RHS, SignShift);
  } else {
    // Unsigned operands extend with zeros.
    HiLHS = DAG.getConstant(0, dl, VT);
    HiRHS = DAG.getConstant(0, dl, VT);
  }

  // The full product needs an integer twice as wide as the operands.
  EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
  forceExpandWideMUL(DAG, dl, Signed, WideVT, LHS, HiLHS, RHS, HiRHS, Lo, Hi);
}

SDValue
TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
assert((Node->getOpcode() == ISD::SMULFIX ||
Expand Down Expand Up @@ -10223,7 +10339,7 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
} else if (VT.isVector()) {
return SDValue();
} else {
report_fatal_error("Unable to expand fixed point multiplication.");
forceExpandWideMUL(DAG, dl, Signed, LHS, RHS, Lo, Hi);
}

if (Scale == VTSize)
Expand Down Expand Up @@ -10522,69 +10638,7 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
if (VT.isVector())
return false;

// We can fall back to a libcall with an illegal type for the MUL if we
// have a libcall big enough.
// Also, we can fall back to a division in some cases, but that's a big
// performance hit in the general case.
RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
if (WideVT == MVT::i16)
LC = RTLIB::MUL_I16;
else if (WideVT == MVT::i32)
LC = RTLIB::MUL_I32;
else if (WideVT == MVT::i64)
LC = RTLIB::MUL_I64;
else if (WideVT == MVT::i128)
LC = RTLIB::MUL_I128;
assert(LC != RTLIB::UNKNOWN_LIBCALL && "Cannot expand this operation!");

SDValue HiLHS;
SDValue HiRHS;
if (isSigned) {
// The high part is obtained by SRA'ing all but one of the bits of low
// part.
unsigned LoSize = VT.getFixedSizeInBits();
HiLHS =
DAG.getNode(ISD::SRA, dl, VT, LHS,
DAG.getConstant(LoSize - 1, dl,
getPointerTy(DAG.getDataLayout())));
HiRHS =
DAG.getNode(ISD::SRA, dl, VT, RHS,
DAG.getConstant(LoSize - 1, dl,
getPointerTy(DAG.getDataLayout())));
} else {
HiLHS = DAG.getConstant(0, dl, VT);
HiRHS = DAG.getConstant(0, dl, VT);
}

// Here we're passing the 2 arguments explicitly as 4 arguments that are
// pre-lowered to the correct types. This all depends upon WideVT not
// being a legal type for the architecture and thus has to be split to
// two arguments.
SDValue Ret;
TargetLowering::MakeLibCallOptions CallOptions;
CallOptions.setSExt(isSigned);
CallOptions.setIsPostTypeLegalization(true);
if (shouldSplitFunctionArgumentsAsLittleEndian(DAG.getDataLayout())) {
// Halves of WideVT are packed into registers in different order
// depending on platform endianness. This is usually handled by
// the C calling convention, but we can't defer to it in
// the legalizer.
SDValue Args[] = { LHS, HiLHS, RHS, HiRHS };
Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
} else {
SDValue Args[] = { HiLHS, LHS, HiRHS, RHS };
Ret = makeLibCall(DAG, LC, WideVT, Args, CallOptions, dl).first;
}
assert(Ret.getOpcode() == ISD::MERGE_VALUES &&
"Ret value is a collection of constituent nodes holding result.");
if (DAG.getDataLayout().isLittleEndian()) {
// Same as above.
BottomHalf = Ret.getOperand(0);
TopHalf = Ret.getOperand(1);
} else {
BottomHalf = Ret.getOperand(1);
TopHalf = Ret.getOperand(0);
}
forceExpandWideMUL(DAG, dl, isSigned, LHS, RHS, BottomHalf, TopHalf);
}

Result = BottomHalf;
Expand Down
Loading