Skip to content

Commit d1d00dd

Browse files
committed
[ADT] Add signed and unsigned mulHi and mulLo to APInt
This addresses issue #84207
1 parent d9c8550 commit d1d00dd

File tree

4 files changed

+129
-13
lines changed

4 files changed

+129
-13
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
21932193
return A.uge(B) ? (A - B) : (B - A);
21942194
}
21952195

2196+
/// Return the upper C1.getBitWidth() bits of the full signed product C1 * C2.
2197+
APInt mulHiS(const APInt &C1, const APInt &C2);
2198+
2199+
/// Return the upper C1.getBitWidth() bits of the full unsigned product C1 * C2.
2200+
APInt mulHiU(const APInt &C1, const APInt &C2);
2201+
2202+
/// Return the lower C1.getBitWidth() bits of the full signed product C1 * C2.
2203+
APInt mulLoS(const APInt &C1, const APInt &C2);
2204+
2205+
/// Return the lower C1.getBitWidth() bits of the full unsigned product C1 * C2.
2206+
APInt mulLoU(const APInt &C1, const APInt &C2);
2207+
21962208
/// Compute GCD of two unsigned APInt values.
21972209
///
21982210
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6009,18 +6009,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60096009
if (!C2.getBoolValue())
60106010
break;
60116011
return C1.srem(C2);
6012-
case ISD::MULHS: {
6013-
unsigned FullWidth = C1.getBitWidth() * 2;
6014-
APInt C1Ext = C1.sext(FullWidth);
6015-
APInt C2Ext = C2.sext(FullWidth);
6016-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6017-
}
6018-
case ISD::MULHU: {
6019-
unsigned FullWidth = C1.getBitWidth() * 2;
6020-
APInt C1Ext = C1.zext(FullWidth);
6021-
APInt C2Ext = C2.zext(FullWidth);
6022-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6023-
}
6012+
case ISD::MULHS:
6013+
return APIntOps::mulHiS(C1, C2);
6014+
case ISD::MULHU:
6015+
return APIntOps::mulHiU(C1, C2);
60246016
case ISD::AVGFLOORS: {
60256017
unsigned FullWidth = C1.getBitWidth() + 1;
60266018
APInt C1Ext = C1.sext(FullWidth);
@@ -6706,8 +6698,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
67066698
break;
67076699
case ISD::UDIV:
67086700
case ISD::UREM:
6709-
case ISD::MULHU:
67106701
case ISD::MULHS:
6702+
case ISD::MULHU:
67116703
case ISD::SDIV:
67126704
case ISD::SREM:
67136705
case ISD::SADDSAT:

llvm/lib/Support/APInt.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,3 +3094,31 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
30943094
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
30953095
}
30963096
}
3097+
3098+
/// Return the high half of the signed product of \p C1 and \p C2.
///
/// Both operands are sign-extended to twice the bit width of \p C1 and
/// multiplied at that width. The result is the upper C1.getBitWidth() bits of
/// that product, so it has the same width as \p C1.
///
/// \p C1 and \p C2 must have the same bit width.
APInt APIntOps::mulHiS(const APInt &C1, const APInt &C2) {
  // Each operand is widened independently below, so a width mismatch could
  // otherwise slip through and produce a meaningless result.
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  const unsigned Width = C1.getBitWidth();
  const unsigned FullWidth = Width * 2;
  const APInt Product = C1.sext(FullWidth) * C2.sext(FullWidth);
  // Take Width bits starting at bit position Width, i.e. the upper half.
  return Product.extractBits(Width, Width);
}
3104+
3105+
/// Return the high half of the unsigned product of \p C1 and \p C2.
///
/// Both operands are zero-extended to twice the bit width of \p C1 and
/// multiplied at that width. The result is the upper C1.getBitWidth() bits of
/// that product, so it has the same width as \p C1.
///
/// \p C1 and \p C2 must have the same bit width.
APInt APIntOps::mulHiU(const APInt &C1, const APInt &C2) {
  // Each operand is widened independently below, so a width mismatch could
  // otherwise slip through and produce a meaningless result.
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  const unsigned Width = C1.getBitWidth();
  const unsigned FullWidth = Width * 2;
  const APInt Product = C1.zext(FullWidth) * C2.zext(FullWidth);
  // Take Width bits starting at bit position Width, i.e. the upper half.
  return Product.extractBits(Width, Width);
}
3111+
3112+
/// Return the low half of the signed product of \p C1 and \p C2.
///
/// The low C1.getBitWidth() bits of a product are the same whether the
/// operands are interpreted as signed or unsigned, so the operands do not need
/// to be widened first: a multiplication at the operands' own width yields
/// exactly those bits.
///
/// \p C1 and \p C2 must have the same bit width.
APInt APIntOps::mulLoS(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // Same-width APInt multiplication keeps only the low bits of the product,
  // which is what truncating the double-width product would have produced.
  return C1 * C2;
}
3118+
3119+
/// Return the low half of the unsigned product of \p C1 and \p C2.
///
/// The low C1.getBitWidth() bits of a product are the same whether the
/// operands are interpreted as signed or unsigned, so the operands do not need
/// to be widened first: a multiplication at the operands' own width yields
/// exactly those bits.
///
/// \p C1 and \p C2 must have the same bit width.
APInt APIntOps::mulLoU(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // Same-width APInt multiplication keeps only the low bits of the product,
  // which is what truncating the double-width product would have produced.
  return C1 * C2;
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/ArrayRef.h"
1111
#include "llvm/ADT/DenseMap.h"
1212
#include "llvm/ADT/SmallString.h"
13+
#include "llvm/ADT/StringExtras.h"
1314
#include "llvm/ADT/Twine.h"
1415
#include "llvm/Support/Alignment.h"
1516
#include "gtest/gtest.h"
@@ -2805,6 +2806,89 @@ TEST(APIntTest, multiply) {
28052806
EXPECT_EQ(64U, i96.countr_zero());
28062807
}
28072808

2809+
// Checks the high and low halves returned by the APIntOps multiply helpers for
// 32-, 64- and 128-bit operands, unsigned first and then signed.
TEST(APIntOpsTest, MulHiLo) {
  // Renders a value as text in the same form the 128-bit expectations below
  // are written in.
  auto ToHex = [](const APInt &Value) -> std::string {
    return toString(Value, 16, false, true, true, true);
  };

  // ---- Unsigned ----

  // 32-bit operands.
  APInt U32Lhs(32, 0x0001'E235);
  APInt U32Rhs(32, 0xF623'55AD);
  EXPECT_EQ(0x0001'CFA1, APIntOps::mulHiU(U32Lhs, U32Rhs));
  EXPECT_EQ(0x7CA0'76D1, APIntOps::mulLoU(U32Lhs, U32Rhs));

  // 64-bit operands.
  APInt U64Lhs(64, 0x1234'5678'90AB'CDEF);
  APInt U64Rhs(64, 0xFEDC'BA09'8765'4321);
  EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulHiU(U64Lhs, U64Rhs));
  EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoU(U64Lhs, U64Rhs));

  // 128-bit operands; the results are compared as text.
  APInt U128Lhs(128, "1234567890ABCDEF1234567890ABCDEF", 16);
  APInt U128Rhs(128, "FEDCBA0987654321FEDCBA0987654321", 16);
  std::string HighText = ToHex(APIntOps::mulHiU(U128Lhs, U128Rhs));
  EXPECT_STREQ("0x121F'A000'A372'3A57'E689'8431'2C3A'8D7E", HighText.c_str());
  std::string LowText = ToHex(APIntOps::mulLoU(U128Lhs, U128Rhs));
  EXPECT_STREQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF", LowText.c_str());

  // ---- Signed ----

  // 32-bit operands.
  APInt S32PosA(32, 0x1234'5678); // positive
  APInt S32PosB(32, 0x10AB'CDEF); // positive
  APInt S32Neg(32, 0xFEDC'BA09);  // negative

  // positive * positive
  EXPECT_EQ(0x012F'7D02, APIntOps::mulHiS(S32PosA, S32PosB));
  EXPECT_EQ(0x2A42'D208, APIntOps::mulLoS(S32PosA, S32PosB));

  // positive * negative
  EXPECT_EQ(0xFFEB'4988, APIntOps::mulHiS(S32PosA, S32Neg));
  EXPECT_EQ(0x09CA'3A38, APIntOps::mulLoS(S32PosA, S32Neg));

  // negative * negative
  EXPECT_EQ(0x0001'4B68, APIntOps::mulHiS(S32Neg, S32Neg));
  EXPECT_EQ(0x22A9'1451, APIntOps::mulLoS(S32Neg, S32Neg));

  // 64-bit operands.
  APInt S64PosA(64, 0x1234'5678'90AB'CDEF); // positive
  APInt S64PosB(64, 0x1234'5678'90FE'DCBA); // positive
  APInt S64Neg(64, 0xFEDC'BA09'8765'4321);  // negative

  // positive * positive
  EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulHiS(S64PosA, S64PosB));
  EXPECT_EQ(0xFB99'7041'84EF'03A6, APIntOps::mulLoS(S64PosA, S64PosB));

  // positive * negative
  EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulHiS(S64PosA, S64Neg));
  EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoS(S64PosA, S64Neg));

  // negative * negative
  EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulHiS(S64Neg, S64Neg));
  EXPECT_EQ(0xCEFE'A12C'D7A4'4A41, APIntOps::mulLoS(S64Neg, S64Neg));

  // 128-bit operands; the results are compared as text.
  APInt S128PosA(128, "1234567890ABCDEF1234567890ABCDEF", 16); // positive
  APInt S128PosB(128, "1234567890FEDCBA1234567890FEDCBA", 16); // positive
  APInt S128Neg(128, "FEDCBA0987654321FEDCBA0987654321", 16);  // negative

  // positive * positive
  HighText = ToHex(APIntOps::mulHiS(S128PosA, S128PosB));
  EXPECT_STREQ("0x14B'66DC'328E'10C1'FE30'3DF9'EA0B'2529", HighText.c_str());
  LowText = ToHex(APIntOps::mulLoS(S128PosA, S128PosB));
  EXPECT_STREQ("0xF87E'475F'3C6C'180D'FB99'7041'84EF'03A6", LowText.c_str());

  // positive * negative
  HighText = ToHex(APIntOps::mulHiS(S128PosA, S128Neg));
  EXPECT_STREQ("0xFFEB'4988'12C6'6C68'D455'2DB8'9B8E'BF8F", HighText.c_str());
  LowText = ToHex(APIntOps::mulLoS(S128PosA, S128Neg));
  EXPECT_STREQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF", LowText.c_str());

  // negative * negative
  HighText = ToHex(APIntOps::mulHiS(S128Neg, S128Neg));
  EXPECT_STREQ("0x1'4B68'2174'FA18'CCBA'AC10'2958'C4B5", HighText.c_str());
  LowText = ToHex(APIntOps::mulLoS(S128Neg, S128Neg));
  EXPECT_STREQ("0x9BB8'01D4'DF88'14DC'CEFE'A12C'D7A4'4A41", LowText.c_str());
}
2891+
28082892
TEST(APIntTest, RoundingUDiv) {
28092893
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
28102894
APInt A(8, Ai);

0 commit comments

Comments
 (0)