Skip to content

Commit 85d1442

Browse files
committed
[ADT] Add signed and unsigned mulh to APInt
This addresses issue #84207
1 parent 1d99d7a commit 85d1442

File tree

6 files changed

+84
-24
lines changed

6 files changed

+84
-24
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,6 +2216,14 @@ APInt avgCeilS(const APInt &C1, const APInt &C2);
22162216
/// Compute the ceil of the unsigned average of C1 and C2
22172217
APInt avgCeilU(const APInt &C1, const APInt &C2);
22182218

2219+
/// Performs (2*N)-bit multiplication on sign-extended operands.
2220+
/// Returns the high N bits of the multiplication result.
2221+
APInt mulhs(const APInt &C1, const APInt &C2);
2222+
2223+
/// Performs (2*N)-bit multiplication on zero-extended operands.
2224+
/// Returns the high N bits of the multiplication result.
2225+
APInt mulhu(const APInt &C1, const APInt &C2);
2226+
22192227
/// Compute GCD of two unsigned APInt values.
22202228
///
22212229
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6073,18 +6073,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60736073
if (!C2.getBoolValue())
60746074
break;
60756075
return C1.srem(C2);
6076-
case ISD::MULHS: {
6077-
unsigned FullWidth = C1.getBitWidth() * 2;
6078-
APInt C1Ext = C1.sext(FullWidth);
6079-
APInt C2Ext = C2.sext(FullWidth);
6080-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6081-
}
6082-
case ISD::MULHU: {
6083-
unsigned FullWidth = C1.getBitWidth() * 2;
6084-
APInt C1Ext = C1.zext(FullWidth);
6085-
APInt C2Ext = C2.zext(FullWidth);
6086-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6087-
}
60886076
case ISD::AVGFLOORS:
60896077
return APIntOps::avgFloorS(C1, C2);
60906078
case ISD::AVGFLOORU:
@@ -6097,10 +6085,13 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60976085
return APIntOps::abds(C1, C2);
60986086
case ISD::ABDU:
60996087
return APIntOps::abdu(C1, C2);
6088+
case ISD::MULHS:
6089+
return APIntOps::mulhs(C1, C2);
6090+
case ISD::MULHU:
6091+
return APIntOps::mulhu(C1, C2);
61006092
}
61016093
return std::nullopt;
61026094
}
6103-
61046095
// Handle constant folding with UNDEF.
61056096
// TODO: Handle more cases.
61066097
static std::optional<APInt> FoldValueWithUndef(unsigned Opcode, const APInt &C1,

llvm/lib/Support/APInt.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3121,3 +3121,19 @@ APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
31213121
// Return ceil((C1 + C2) / 2)
31223122
return (C1 | C2) - (C1 ^ C2).lshr(1);
31233123
}
3124+
3125+
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
3126+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3127+
unsigned FullWidth = C1.getBitWidth() * 2;
3128+
APInt C1Ext = C1.sext(FullWidth);
3129+
APInt C2Ext = C2.sext(FullWidth);
3130+
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
3131+
}
3132+
3133+
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
3134+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3135+
unsigned FullWidth = C1.getBitWidth() * 2;
3136+
APInt C1Ext = C1.zext(FullWidth);
3137+
APInt C2Ext = C2.zext(FullWidth);
3138+
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
3139+
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,6 +2841,58 @@ TEST(APIntTest, multiply) {
28412841
EXPECT_EQ(64U, i96.countr_zero());
28422842
}
28432843

2844+
TEST(APIntOpsTest, Mulh) {
2845+
2846+
// Unsigned
2847+
2848+
// 32 bits
2849+
APInt i32a(32, 0x0001'E235);
2850+
APInt i32b(32, 0xF623'55AD);
2851+
EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(i32a, i32b));
2852+
2853+
// 64 bits
2854+
APInt i64a(64, 0x1234'5678'90AB'CDEF);
2855+
APInt i64b(64, 0xFEDC'BA09'8765'4321);
2856+
EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(i64a, i64b));
2857+
2858+
// 128 bits
2859+
APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
2860+
APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
2861+
APInt i128Res = APIntOps::mulhu(i128a, i128b);
2862+
EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), i128Res);
2863+
2864+
// Signed
2865+
2866+
// 32 bits
2867+
APInt i32c(32, 0x1234'5678); // +ve
2868+
APInt i32d(32, 0x10AB'CDEF); // +ve
2869+
APInt i32e(32, 0xFEDC'BA09); // -ve
2870+
2871+
EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(i32c, i32d));
2872+
EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(i32c, i32e));
2873+
EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(i32e, i32e));
2874+
2875+
// 64 bits
2876+
APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve
2877+
APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve
2878+
APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve
2879+
2880+
EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(i64c, i64d));
2881+
EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(i64c, i64e));
2882+
EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(i64e, i64e));
2883+
2884+
// 128 bits
2885+
APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve
2886+
APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve
2887+
APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve
2888+
2889+
i128Res = APIntOps::mulhs(i128c, i128d);
2890+
EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), i128Res);
2891+
2892+
i128Res = APIntOps::mulhs(i128c, i128e);
2893+
EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res);
2894+
}
2895+
28442896
TEST(APIntTest, RoundingUDiv) {
28452897
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
28462898
APInt A(8, Ai);

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
434434
// Invoke the constant fold helper again to calculate the 'high' result.
435435
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
436436
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
437-
unsigned bitWidth = a.getBitWidth();
438-
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
439-
return fullProduct.extractBits(bitWidth, bitWidth);
437+
return llvm::APIntOps::mulhs(a, b);
440438
});
441439
assert(highAttr && "Unexpected constant-folding failure");
442440

@@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
491489
// Invoke the constant fold helper again to calculate the 'high' result.
492490
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
493491
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
494-
unsigned bitWidth = a.getBitWidth();
495-
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
496-
return fullProduct.extractBits(bitWidth, bitWidth);
492+
return llvm::APIntOps::mulhu(a, b);
497493
});
498494
assert(highAttr && "Unexpected constant-folding failure");
499495

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {
250250

251251
auto highBits = constFoldBinaryOp<IntegerAttr>(
252252
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253-
unsigned bitWidth = a.getBitWidth();
254-
APInt c;
255253
if (IsSigned) {
256-
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
254+
return llvm::APIntOps::mulhs(a, b);
257255
} else {
258-
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
256+
return llvm::APIntOps::mulhu(a, b);
259257
}
260-
return c.extractBits(bitWidth, bitWidth); // Extract high result
261258
});
262259

263260
if (!highBits)

0 commit comments

Comments
 (0)