Skip to content

[ADT] Add implementations for mulhs and mulhu to APInt #84609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,16 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}

/// Compute the high half of the unsigned product of two APInts.
///
/// Both operands are zero-extended to twice the bit width of \p C1 and
/// multiplied; the result is the upper `C1.getBitWidth()` bits of that
/// double-width product, i.e. `(zext(C1) * zext(C2)) >> C1.getBitWidth()`.
/// \p C1 and \p C2 are expected to have the same bit width.
APInt mulhu(const APInt &C1, const APInt &C2);

/// Compute the high half of the signed product of two APInts.
///
/// Both operands are sign-extended to twice the bit width of \p C1 and
/// multiplied; the result is the upper `C1.getBitWidth()` bits of that
/// double-width product, so the sign of the full product is preserved.
/// Example with 32-bit operands: `mulhs(-2097152, 524288) == -256`.
/// \p C1 and \p C2 are expected to have the same bit width.
APInt mulhs(const APInt &C1, const APInt &C2);

/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
Expand Down
16 changes: 4 additions & 12 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6015,18 +6015,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
case ISD::MULHS: {
unsigned FullWidth = C1.getBitWidth() * 2;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
case ISD::MULHU: {
unsigned FullWidth = C1.getBitWidth() * 2;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
case ISD::MULHU:
return APIntOps::mulhu(C1, C2);
case ISD::MULHS:
return APIntOps::mulhs(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
}
}

/// Return the upper half of the unsigned double-width product of \p C1 and
/// \p C2. Both operands must have the same bit width; the result has that
/// width as well.
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
  // Mismatched widths would make "the high half" ill-defined, so reject them
  // up front instead of failing later inside zext().
  assert(C1.getBitWidth() == C2.getBitWidth() &&
         "Bit widths must be the same");
  // Zero-extend to twice the width so the multiplication cannot overflow.
  unsigned FullWidth = C1.getBitWidth() * 2;
  APInt C1Ext = C1.zext(FullWidth);
  APInt C2Ext = C2.zext(FullWidth);
  // The high half starts at bit C1.getBitWidth() of the wide product.
  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}

/// Return the upper half of the signed double-width product of \p C1 and
/// \p C2. Both operands must have the same bit width; the result has that
/// width as well.
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
  // Mismatched widths would make "the high half" ill-defined, so reject them
  // up front instead of failing later inside sext().
  assert(C1.getBitWidth() == C2.getBitWidth() &&
         "Bit widths must be the same");
  // Sign-extend to twice the width so the multiplication cannot overflow and
  // the sign of the full product is kept.
  unsigned FullWidth = C1.getBitWidth() * 2;
  APInt C1Ext = C1.sext(FullWidth);
  APInt C2Ext = C2.sext(FullWidth);
  // The high half starts at bit C1.getBitWidth() of the wide product.
  return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}

/// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
/// from Src into IntVal, which is assumed to be wide enough and to hold zero.
void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
Expand Down
20 changes: 20 additions & 0 deletions llvm/unittests/ADT/APIntTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}

// Checks APIntOps::mulhu and APIntOps::mulhs on 32-bit operands, covering
// both operand orders and negative inputs.
TEST(APIntTest, Hmultiply) {
  // 2^20 * 2^20 = 2^40, whose high 32 bits are 2^8.
  APInt i1048576(32, 1048576);

  EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));

  // 2^24 * 2^15 = 2^39, whose high 32 bits are 2^7 in either operand order.
  APInt i16777216(32, 16777216);
  APInt i32768(32, 32768);

  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
  EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));

  // Negative values are constructed with isSigned=true so the value is
  // treated as a signed quantity rather than relying on implicit truncation
  // of the uint64_t constructor argument.
  APInt iNeg2097152(32, -2097152, /*isSigned=*/true);
  APInt i524288(32, 524288);

  // (-2^21) * (-2^21) = 2^42, whose high 32 bits are 2^10.
  EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(iNeg2097152, iNeg2097152));

  // (-2^21) * 2^19 = -2^40, whose high 32 bits are -2^8 in either order.
  EXPECT_EQ(APInt(32, -256, /*isSigned=*/true),
            APIntOps::mulhs(iNeg2097152, i524288));
  EXPECT_EQ(APInt(32, -256, /*isSigned=*/true),
            APIntOps::mulhs(i524288, iNeg2097152));
}

TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
Expand Down
20 changes: 4 additions & 16 deletions llvm/unittests/Support/DivisionByConstantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
} while (++N != 0);
}

// Reference signed multiply-high: sign-extend both operands to twice the
// width of X, multiply, and return the upper half of the product at X's width.
APInt MULHS(APInt X, APInt Y) {
  const unsigned NarrowWidth = X.getBitWidth();
  const unsigned DoubleWidth = 2 * NarrowWidth;
  const APInt WideProduct = X.sext(DoubleWidth) * Y.sext(DoubleWidth);
  // Move the high half into the low bits, then drop the vacated upper bits.
  const APInt HighHalf = WideProduct.lshr(NarrowWidth);
  return HighHalf.trunc(NarrowWidth);
}

APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
SignedDivisionByConstantInfo Magics) {
unsigned Bits = Numerator.getBitWidth();
Expand All @@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
}

// Multiply the numerator by the magic value.
APInt Q = MULHS(Numerator, Magics.Magic);
APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);

// (Optionally) Add/subtract the numerator using Factor.
Factor = Numerator * Factor;
Expand Down Expand Up @@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
}
}

// Reference unsigned multiply-high: zero-extend both operands to twice the
// width of X, multiply, and return the upper half of the product at X's width.
APInt MULHU(APInt X, APInt Y) {
  const unsigned NarrowWidth = X.getBitWidth();
  const unsigned DoubleWidth = 2 * NarrowWidth;
  const APInt WideProduct = X.zext(DoubleWidth) * Y.zext(DoubleWidth);
  // Move the high half into the low bits, then drop the vacated upper bits.
  const APInt HighHalf = WideProduct.lshr(NarrowWidth);
  return HighHalf.trunc(NarrowWidth);
}

APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
bool LZOptimization,
bool AllowEvenDivisorOptimization, bool ForceNPQ,
Expand Down Expand Up @@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
APInt Q = Numerator.lshr(PreShift);

// Multiply the numerator by the magic value.
Q = MULHU(Q, Magics.Magic);
Q = APIntOps::mulhu(Q, Magics.Magic);

if (UseNPQ || ForceNPQ) {
APInt NPQ = Numerator - Q;

// For vectors we might have a mix of non-NPQ/NPQ paths, so use
// MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
// mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
APInt NPQ_Scalar = NPQ.lshr(1);
(void)NPQ_Scalar;
NPQ = MULHU(NPQ, NPQFactor);
NPQ = APIntOps::mulhu(NPQ, NPQFactor);
assert(!UseNPQ || NPQ == NPQ_Scalar);

Q = NPQ + Q;
Expand Down
10 changes: 2 additions & 8 deletions llvm/unittests/Support/KnownBitsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,19 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhs(Known1, Known2);
},
[](const APInt &N1, const APInt &N2) {
unsigned Bits = N1.getBitWidth();
return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
},
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::mulhu(Known1, Known2);
},
[](const APInt &N1, const APInt &N2) {
unsigned Bits = N1.getBitWidth();
return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
},
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
checkCorrectnessOnlyBinary);
}

Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
return llvm::APIntOps::mulhs(a, b);
});
assert(highAttr && "Unexpected constant-folding failure");

Expand Down Expand Up @@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
return llvm::APIntOps::mulhu(a, b);
});
assert(highAttr && "Unexpected constant-folding failure");

Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {

auto highBits = constFoldBinaryOp<IntegerAttr>(
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt c;
if (IsSigned) {
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
return llvm::APIntOps::mulhs(a, b);
} else {
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
return llvm::APIntOps::mulhu(a, b);
}
return c.extractBits(bitWidth, bitWidth); // Extract high result
});

if (!highBits)
Expand Down