Skip to content

Commit 362caa1

Browse files
committed
[ADT] Add implementations for avgFloor and avgCeil to APInt
Supports both signed and unsigned expansions. SelectionDAG now calls the APInt implementation of these functions.
1 parent d9c8550 commit 362caa1

File tree

4 files changed

+142
-24
lines changed

4 files changed

+142
-24
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
21932193
return A.uge(B) ? (A - B) : (B - A);
21942194
}
21952195

2196+
/// Compute the floor of the signed average of C1 and C2: floor((C1 + C2) / 2).
2197+
APInt avgFloorS(const APInt &C1, const APInt &C2);
2198+
2199+
/// Compute the floor of the unsigned average of C1 and C2: floor((C1 + C2) / 2).
2200+
APInt avgFloorU(const APInt &C1, const APInt &C2);
2201+
2202+
/// Compute the ceil of the signed average of C1 and C2: ceil((C1 + C2) / 2).
2203+
APInt avgCeilS(const APInt &C1, const APInt &C2);
2204+
2205+
/// Compute the ceil of the unsigned average of C1 and C2: ceil((C1 + C2) / 2).
2206+
APInt avgCeilU(const APInt &C1, const APInt &C2);
2207+
21962208
/// Compute GCD of two unsigned APInt values.
21972209
///
21982210
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6021,30 +6021,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60216021
APInt C2Ext = C2.zext(FullWidth);
60226022
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
60236023
}
6024-
case ISD::AVGFLOORS: {
6025-
unsigned FullWidth = C1.getBitWidth() + 1;
6026-
APInt C1Ext = C1.sext(FullWidth);
6027-
APInt C2Ext = C2.sext(FullWidth);
6028-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6029-
}
6030-
case ISD::AVGFLOORU: {
6031-
unsigned FullWidth = C1.getBitWidth() + 1;
6032-
APInt C1Ext = C1.zext(FullWidth);
6033-
APInt C2Ext = C2.zext(FullWidth);
6034-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6035-
}
6036-
case ISD::AVGCEILS: {
6037-
unsigned FullWidth = C1.getBitWidth() + 1;
6038-
APInt C1Ext = C1.sext(FullWidth);
6039-
APInt C2Ext = C2.sext(FullWidth);
6040-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6041-
}
6042-
case ISD::AVGCEILU: {
6043-
unsigned FullWidth = C1.getBitWidth() + 1;
6044-
APInt C1Ext = C1.zext(FullWidth);
6045-
APInt C2Ext = C2.zext(FullWidth);
6046-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6047-
}
6024+
case ISD::AVGFLOORS:
6025+
return APIntOps::avgFloorS(C1, C2);
6026+
case ISD::AVGFLOORU:
6027+
return APIntOps::avgFloorU(C1, C2);
6028+
case ISD::AVGCEILS:
6029+
return APIntOps::avgCeilS(C1, C2);
6030+
case ISD::AVGCEILU:
6031+
return APIntOps::avgCeilU(C1, C2);
60486032
case ISD::ABDS:
60496033
return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2);
60506034
case ISD::ABDU:

llvm/lib/Support/APInt.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
30943094
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
30953095
}
30963096
}
3097+
3098+
APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
3099+
// Return floor((C1 + C2)/2); operands are sign-extended by one bit before the add, then bit 0 is dropped.
3100+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3101+
unsigned FullWidth = C1.getBitWidth() + 1;
3102+
APInt C1Ext = C1.sext(FullWidth);
3103+
APInt C2Ext = C2.sext(FullWidth);
3104+
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
3105+
}
3106+
3107+
APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
3108+
// Return floor((C1 + C2)/2); operands are zero-extended by one bit before the add, then bit 0 is dropped.
3109+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3110+
unsigned FullWidth = C1.getBitWidth() + 1;
3111+
APInt C1Ext = C1.zext(FullWidth);
3112+
APInt C2Ext = C2.zext(FullWidth);
3113+
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
3114+
}
3115+
3116+
APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
3117+
// Return ceil((C1 + C2)/2); operands are sign-extended by one bit, and 1 is added to the sum before bit 0 is dropped.
3118+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3119+
unsigned FullWidth = C1.getBitWidth() + 1;
3120+
APInt C1Ext = C1.sext(FullWidth);
3121+
APInt C2Ext = C2.sext(FullWidth);
3122+
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
3123+
}
3124+
3125+
APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
3126+
// Return ceil((C1 + C2)/2); operands are zero-extended by one bit, and 1 is added to the sum before bit 0 is dropped.
3127+
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
3128+
unsigned FullWidth = C1.getBitWidth() + 1;
3129+
APInt C1Ext = C1.zext(FullWidth);
3130+
APInt C2Ext = C2.zext(FullWidth);
3131+
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
3132+
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/Support/Alignment.h"
1515
#include "gtest/gtest.h"
1616
#include <array>
17+
#include <climits>
1718
#include <optional>
1819

1920
using namespace llvm;
@@ -2877,6 +2878,91 @@ TEST(APIntTest, RoundingSDiv) {
28772878
}
28782879
}
28792880

2881+
TEST(APIntTest, Average) {
// Checks APIntOps::avgFloorU/avgCeilU/avgFloorS/avgCeilS on 32-bit values
// against literal results and against RoundingUDiv/RoundingSDiv by 2.
2882+
APInt A0(32, 0);
2883+
APInt A2(32, 2);
2884+
APInt A100(32, 100);
2885+
APInt A101(32, 101);
2886+
APInt A200(32, 200, false);
2887+
APInt ApUMax(32, UINT_MAX, false);
2888+
// Unsigned averages: even sum, odd sum, zero, and UINT_MAX operands.
2889+
EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200));
2890+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
2891+
APIntOps::avgFloorU(A100, A200));
2892+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
2893+
APIntOps::avgCeilU(A100, A200));
2894+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
2895+
APIntOps::avgFloorU(A100, A101));
2896+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
2897+
APIntOps::avgCeilU(A100, A101));
2898+
EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
2899+
EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
2900+
EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax));
2901+
EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
2902+
EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN),
2903+
APIntOps::avgFloorU(A0, ApUMax));
2904+
EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP),
2905+
APIntOps::avgCeilU(A0, ApUMax));
2906+
// Signed operands, including INT_MIN and INT_MAX.
2907+
APInt Ap100(32, +100);
2908+
APInt Ap101(32, +101);
2909+
APInt Ap200(32, +200);
2910+
APInt Am1(32, -1);
2911+
APInt Am100(32, -100);
2912+
APInt Am101(32, -101);
2913+
APInt Am200(32, -200);
2914+
APInt AmSMin(32, INT_MIN);
2915+
APInt ApSMax(32, INT_MAX);
2916+
2917+
EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
2918+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
2919+
APIntOps::avgFloorS(Ap100, Ap200));
2920+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
2921+
APIntOps::avgCeilS(Ap100, Ap200));
2922+
2923+
EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200));
2924+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
2925+
APIntOps::avgFloorS(Am100, Am200));
2926+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
2927+
APIntOps::avgCeilS(Am100, Am200));
2928+
2929+
EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101));
2930+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
2931+
APIntOps::avgFloorS(Ap100, Ap101));
2932+
EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101));
2933+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
2934+
APIntOps::avgCeilS(Ap100, Ap101));
2935+
2936+
EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101));
2937+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
2938+
APIntOps::avgFloorS(Am100, Am101));
2939+
EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101));
2940+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
2941+
APIntOps::avgCeilS(Am100, Am101));
2942+
2943+
EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin));
2944+
EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));
2945+
2946+
EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN),
2947+
APIntOps::avgFloorS(A0, AmSMin));
2948+
EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP),
2949+
APIntOps::avgCeilS(A0, AmSMin));
2950+
2951+
EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
2952+
EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));
2953+
// INT_MIN + INT_MAX is -1, so the floor average is -1 and the ceil average is 0.
2954+
EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax));
2955+
EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax));
2956+
2957+
EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN),
2958+
APIntOps::avgFloorS(A0, ApSMax));
2959+
EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP),
2960+
APIntOps::avgCeilS(A0, ApSMax));
2961+
2962+
EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax));
2963+
EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
2964+
}
2965+
28802966
TEST(APIntTest, umul_ov) {
28812967
const std::pair<uint64_t, uint64_t> Overflows[] = {
28822968
{0x8000000000000000, 2},

0 commit comments

Comments
 (0)