Skip to content

Commit aff0570

Browse files
authored
[ADT] Add implementations for avgFloor and avgCeil to APInt (#84431)
Supports both signed and unsigned expansions. SelectionDAG now calls the APInt implementation of these functions. Fixes #84211.
1 parent afec257 commit aff0570

File tree

4 files changed

+142
-24
lines changed

4 files changed

+142
-24
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,6 +2198,18 @@ inline const APInt abdu(const APInt &A, const APInt &B) {
21982198
return A.uge(B) ? (A - B) : (B - A);
21992199
}
22002200

2201+
/// Compute the floor of the signed average of C1 and C2
///
/// Both operands are interpreted as signed values and must have the same bit
/// width; the result has that bit width as well. The average is formed without
/// losing the carry of C1 + C2, and an inexact result is rounded towards
/// negative infinity (e.g. the average of -100 and -101 is -101).
2202+
APInt avgFloorS(const APInt &C1, const APInt &C2);
2203+
2204+
/// Compute the floor of the unsigned average of C1 and C2
///
/// Both operands are interpreted as unsigned values and must have the same bit
/// width; the result has that bit width as well. The average is formed without
/// losing the carry of C1 + C2, and an inexact result is rounded down.
2205+
APInt avgFloorU(const APInt &C1, const APInt &C2);
2206+
2207+
/// Compute the ceil of the signed average of C1 and C2
///
/// Both operands are interpreted as signed values and must have the same bit
/// width; the result has that bit width as well. The average is formed without
/// losing the carry of C1 + C2 + 1, and an inexact result is rounded towards
/// positive infinity (e.g. the average of -100 and -101 is -100).
2208+
APInt avgCeilS(const APInt &C1, const APInt &C2);
2209+
2210+
/// Compute the ceil of the unsigned average of C1 and C2
///
/// Both operands are interpreted as unsigned values and must have the same bit
/// width; the result has that bit width as well. The average is formed without
/// losing the carry of C1 + C2 + 1, and an inexact result is rounded up.
2211+
APInt avgCeilU(const APInt &C1, const APInt &C2);
2212+
22012213
/// Compute GCD of two unsigned APInt values.
22022214
///
22032215
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6039,30 +6039,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60396039
APInt C2Ext = C2.zext(FullWidth);
60406040
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
60416041
}
6042-
case ISD::AVGFLOORS: {
6043-
unsigned FullWidth = C1.getBitWidth() + 1;
6044-
APInt C1Ext = C1.sext(FullWidth);
6045-
APInt C2Ext = C2.sext(FullWidth);
6046-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6047-
}
6048-
case ISD::AVGFLOORU: {
6049-
unsigned FullWidth = C1.getBitWidth() + 1;
6050-
APInt C1Ext = C1.zext(FullWidth);
6051-
APInt C2Ext = C2.zext(FullWidth);
6052-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6053-
}
6054-
case ISD::AVGCEILS: {
6055-
unsigned FullWidth = C1.getBitWidth() + 1;
6056-
APInt C1Ext = C1.sext(FullWidth);
6057-
APInt C2Ext = C2.sext(FullWidth);
6058-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6059-
}
6060-
case ISD::AVGCEILU: {
6061-
unsigned FullWidth = C1.getBitWidth() + 1;
6062-
APInt C1Ext = C1.zext(FullWidth);
6063-
APInt C2Ext = C2.zext(FullWidth);
6064-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6065-
}
6042+
case ISD::AVGFLOORS:
6043+
return APIntOps::avgFloorS(C1, C2);
6044+
case ISD::AVGFLOORU:
6045+
return APIntOps::avgFloorU(C1, C2);
6046+
case ISD::AVGCEILS:
6047+
return APIntOps::avgCeilS(C1, C2);
6048+
case ISD::AVGCEILU:
6049+
return APIntOps::avgCeilU(C1, C2);
60666050
case ISD::ABDS:
60676051
return APIntOps::abds(C1, C2);
60686052
case ISD::ABDU:

llvm/lib/Support/APInt.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
30943094
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
30953095
}
30963096
}
3097+
3098+
// Return floor((C1 + C2) / 2) for signed operands of equal bit width.
//
// The result has the same bit width as the operands and is exact even when
// C1 + C2 would not fit in that width.
APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // C1 + C2 == 2 * (C1 & C2) + (C1 ^ C2): bits set in both operands count
  // twice, bits set in exactly one operand count once. Halving therefore
  // yields (C1 & C2) + floor((C1 ^ C2) / 2). The arithmetic shift keeps the
  // sign of the XOR term so the rounding is towards negative infinity, and no
  // widening to BitWidth + 1 (with its extra temporaries) is needed.
  return (C1 & C2) + (C1 ^ C2).ashr(1);
}
3106+
3107+
// Return floor((C1 + C2) / 2) for unsigned operands of equal bit width.
//
// The result has the same bit width as the operands and is exact even when
// C1 + C2 would not fit in that width.
APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // C1 + C2 == 2 * (C1 & C2) + (C1 ^ C2): bits set in both operands count
  // twice, bits set in exactly one operand count once. Halving therefore
  // yields (C1 & C2) + floor((C1 ^ C2) / 2), which never overflows the
  // original width, so no widening to BitWidth + 1 is needed.
  return (C1 & C2) + (C1 ^ C2).lshr(1);
}
3115+
3116+
// Return ceil((C1 + C2) / 2) for signed operands of equal bit width.
//
// The result has the same bit width as the operands and is exact even when
// C1 + C2 + 1 would not fit in that width.
APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // C1 + C2 == 2 * (C1 | C2) - (C1 ^ C2): the OR counts every set bit twice,
  // and the XOR removes the surplus for bits set in only one operand.
  // Halving and rounding up therefore yields
  // (C1 | C2) - floor((C1 ^ C2) / 2). The arithmetic shift keeps the sign of
  // the XOR term, and no widening to BitWidth + 1 is needed.
  return (C1 | C2) - (C1 ^ C2).ashr(1);
}
3124+
3125+
// Return ceil((C1 + C2) / 2) for unsigned operands of equal bit width.
//
// The result has the same bit width as the operands and is exact even when
// C1 + C2 + 1 would not fit in that width.
APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
  assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
  // C1 + C2 == 2 * (C1 | C2) - (C1 ^ C2): the OR counts every set bit twice,
  // and the XOR removes the surplus for bits set in only one operand.
  // Halving and rounding up therefore yields
  // (C1 | C2) - floor((C1 ^ C2) / 2), which never overflows the original
  // width, so no widening to BitWidth + 1 is needed.
  return (C1 | C2) - (C1 ^ C2).lshr(1);
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/Support/Alignment.h"
1515
#include "gtest/gtest.h"
1616
#include <array>
17+
#include <climits>
1718
#include <optional>
1819

1920
using namespace llvm;
@@ -2911,6 +2912,91 @@ TEST(APIntTest, RoundingSDiv) {
29112912
}
29122913
}
29132914

2915+
// Checks APIntOps::avgFloorU/avgCeilU/avgFloorS/avgCeilS on 32-bit values
// against hand-computed constants and against RoundingUDiv/RoundingSDiv of
// the (non-overflowing) sum, including the extreme values of the range where
// a naive C1 + C2 would wrap.
TEST(APIntTest, Average) {
  // Unsigned operands.
  APInt A0(32, 0);
  APInt A2(32, 2);
  APInt A100(32, 100);
  APInt A101(32, 101);
  APInt A200(32, 200, /*isSigned=*/false);
  // Use the APInt factory rather than UINT_MAX so the value is the 32-bit
  // maximum regardless of the host's 'unsigned int' width.
  APInt ApUMax = APInt::getMaxValue(32);

  EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200));
  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorU(A100, A200));
  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
            APIntOps::avgCeilU(A100, A200));
  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorU(A100, A101));
  EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
            APIntOps::avgCeilU(A100, A101));
  EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
  EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
  EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax));
  EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorU(A0, ApUMax));
  EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP),
            APIntOps::avgCeilU(A0, ApUMax));

  // Signed operands. Negative values are constructed with isSigned=true so
  // the constructor is told the 64-bit argument is a sign-extended value
  // instead of relying on implicit truncation of a huge unsigned number.
  APInt Ap100(32, +100);
  APInt Ap101(32, +101);
  APInt Ap200(32, +200);
  APInt Am1(32, -1, /*isSigned=*/true);
  APInt Am100(32, -100, /*isSigned=*/true);
  APInt Am101(32, -101, /*isSigned=*/true);
  APInt Am200(32, -200, /*isSigned=*/true);
  // 32-bit signed extremes, independent of the host's 'int' width.
  APInt AmSMin = APInt::getSignedMinValue(32);
  APInt ApSMax = APInt::getSignedMaxValue(32);

  EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(Ap100, Ap200));
  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(Ap100, Ap200));

  EXPECT_EQ(APInt(32, -150, /*isSigned=*/true),
            APIntOps::avgFloorS(Am100, Am200));
  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(Am100, Am200));
  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(Am100, Am200));

  EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101));
  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(Ap100, Ap101));
  EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101));
  EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(Ap100, Ap101));

  EXPECT_EQ(APInt(32, -101, /*isSigned=*/true),
            APIntOps::avgFloorS(Am100, Am101));
  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(Am100, Am101));
  EXPECT_EQ(APInt(32, -100, /*isSigned=*/true),
            APIntOps::avgCeilS(Am100, Am101));
  EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(Am100, Am101));

  // Extremes: the plain 32-bit sum of these operands would wrap.
  EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin));
  EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));

  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(A0, AmSMin));
  EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(A0, AmSMin));

  EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
  EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));

  // SMin + SMax == -1, so the floor is -1 and the ceil is 0.
  EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax));
  EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax));

  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN),
            APIntOps::avgFloorS(A0, ApSMax));
  EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP),
            APIntOps::avgCeilS(A0, ApSMax));

  EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax));
  EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
}
2999+
29143000
TEST(APIntTest, umul_ov) {
29153001
const std::pair<uint64_t, uint64_t> Overflows[] = {
29163002
{0x8000000000000000, 2},

0 commit comments

Comments
 (0)