Skip to content

Add ConstantRangeList::unionWith() and ::intersectWith() #96547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/ConstantRangeList.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class [[nodiscard]] ConstantRangeList {
APInt(64, Upper, /*isSigned=*/true)));
}

/// Return the range list that results from the union of this
/// ConstantRangeList with another ConstantRangeList, "CRL".
ConstantRangeList unionWith(const ConstantRangeList &CRL) const;

/// Return the range list that results from the intersection of this
/// ConstantRangeList with another ConstantRangeList, "CRL".
ConstantRangeList intersectWith(const ConstantRangeList &CRL) const;

/// Return true if this range list is equal to another range list.
bool operator==(const ConstantRangeList &CRL) const {
return Ranges == CRL.Ranges;
Expand Down
89 changes: 89 additions & 0 deletions llvm/lib/IR/ConstantRangeList.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,95 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) {
}
}

ConstantRangeList
ConstantRangeList::unionWith(const ConstantRangeList &CRL) const {
assert(getBitWidth() == CRL.getBitWidth() &&
"ConstantRangeList bitwidths don't agree!");
// Handle common cases.
if (empty())
return CRL;
if (CRL.empty())
return *this;

ConstantRangeList Result;
size_t i = 0, j = 0;
// "PreviousRange" tracks the lowest unioned range that is being processed.
// Its lower is fixed and the upper may be updated over iterations.
ConstantRange PreviousRange(getBitWidth(), false);
if (Ranges[i].getLower().slt(CRL.Ranges[j].getLower())) {
PreviousRange = Ranges[i++];
} else {
PreviousRange = CRL.Ranges[j++];
}

// Try to union "PreviousRange" and "CR". If they are disjoint, push
// "PreviousRange" to the result and assign it to "CR", a new union range.
// Otherwise, update the upper of "PreviousRange" to cover "CR". Note that,
// the lower of "PreviousRange" is always less or equal the lower of "CR".
auto UnionAndUpdateRange = [&PreviousRange,
&Result](const ConstantRange &CR) {
if (PreviousRange.getUpper().slt(CR.getLower())) {
Result.Ranges.push_back(PreviousRange);
PreviousRange = CR;
} else {
PreviousRange = ConstantRange(
PreviousRange.getLower(),
APIntOps::smax(PreviousRange.getUpper(), CR.getUpper()));
}
};
while (i < size() || j < CRL.size()) {
if (j == CRL.size() ||
(i < size() && Ranges[i].getLower().slt(CRL.Ranges[j].getLower()))) {
// Merge PreviousRange with this.
UnionAndUpdateRange(Ranges[i++]);
} else {
// Merge PreviousRange with CRL.
UnionAndUpdateRange(CRL.Ranges[j++]);
}
}
Result.Ranges.push_back(PreviousRange);
return Result;
}

ConstantRangeList
ConstantRangeList::intersectWith(const ConstantRangeList &CRL) const {
assert(getBitWidth() == CRL.getBitWidth() &&
"ConstantRangeList bitwidths don't agree!");

// Handle common cases.
if (empty())
return *this;
if (CRL.empty())
return CRL;

ConstantRangeList Result;
size_t i = 0, j = 0;
while (i < size() && j < CRL.size()) {
auto &Range = this->Ranges[i];
auto &OtherRange = CRL.Ranges[j];

// The intersection of two Ranges is (max(lowers), min(uppers)), and it's
// possible that max(lowers) > min(uppers) if they don't have intersection.
// Add the intersection to result only if it's non-empty.
// To keep simple, we don't call ConstantRange::intersectWith() as it
// considers the complex upper wrapped case and may result two ranges,
// like (2, 8) && (6, 4) = {(2, 4), (6, 8)}.
APInt Start = APIntOps::smax(Range.getLower(), OtherRange.getLower());
APInt End = APIntOps::smin(Range.getUpper(), OtherRange.getUpper());
if (Start.slt(End))
Result.Ranges.push_back(ConstantRange(Start, End));

// Move to the next Range in one list determined by the uppers.
// For example: A = {(0, 2), (4, 8)}; B = {(-2, 5), (6, 10)}
// We need to intersect three pairs: A0 && B0; A1 && B0; A1 && B1.
if (Range.getUpper().slt(OtherRange.getUpper()))
i++;
else
j++;
}
return Result;
}

void ConstantRangeList::print(raw_ostream &OS) const {
interleaveComma(Ranges, OS, [&](ConstantRange CR) {
OS << "(" << CR.getLower() << ", " << CR.getUpper() << ")";
Expand Down
112 changes: 112 additions & 0 deletions llvm/unittests/IR/ConstantRangeListTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,116 @@ TEST_F(ConstantRangeListTest, Insert) {
EXPECT_TRUE(CRL == Expected);
}

ConstantRangeList GetCRL(ArrayRef<std::pair<APInt, APInt>> Pairs) {
SmallVector<ConstantRange, 2> Ranges;
for (auto &[Start, End] : Pairs)
Ranges.push_back(ConstantRange(Start, End));
return ConstantRangeList(Ranges);
}

TEST_F(ConstantRangeListTest, Union) {
APInt APN4 = APInt(64, -4, /*isSigned=*/true);
APInt APN2 = APInt(64, -2, /*isSigned=*/true);
APInt AP0 = APInt(64, 0, /*isSigned=*/true);
APInt AP2 = APInt(64, 2, /*isSigned=*/true);
APInt AP4 = APInt(64, 4, /*isSigned=*/true);
APInt AP6 = APInt(64, 6, /*isSigned=*/true);
APInt AP7 = APInt(64, 7, /*isSigned=*/true);
APInt AP8 = APInt(64, 8, /*isSigned=*/true);
APInt AP10 = APInt(64, 10, /*isSigned=*/true);
APInt AP11 = APInt(64, 11, /*isSigned=*/true);
APInt AP12 = APInt(64, 12, /*isSigned=*/true);
APInt AP16 = APInt(64, 16, /*isSigned=*/true);
APInt AP18 = APInt(64, 18, /*isSigned=*/true);
ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});

// Union with a subset.
ConstantRangeList Empty;
EXPECT_EQ(CRL.unionWith(Empty), CRL);
EXPECT_EQ(Empty.unionWith(CRL), CRL);

EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}})), CRL);
EXPECT_EQ(CRL.unionWith(GetCRL({{AP10, AP12}})), CRL);

EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}, {AP8, AP10}})), CRL);
EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP2}, {AP10, AP12}})), CRL);
EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP4}, {AP8, AP10}})), CRL);
EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP4}, {AP10, AP12}})), CRL);

EXPECT_EQ(CRL.unionWith(GetCRL({{AP0, AP4}, {AP8, AP10}, {AP11, AP12}})),
CRL);

EXPECT_EQ(CRL.unionWith(CRL), CRL);

// Union with new ranges.
EXPECT_EQ(CRL.unionWith(GetCRL({{APN4, APN2}})),
GetCRL({{APN4, APN2}, {AP0, AP4}, {AP8, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP6, AP7}})),
GetCRL({{AP0, AP4}, {AP6, AP7}, {AP8, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP16, AP18}})),
GetCRL({{AP0, AP4}, {AP8, AP12}, {AP16, AP18}}));

EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP2}})),
GetCRL({{APN2, AP4}, {AP8, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP6}})),
GetCRL({{AP0, AP6}, {AP8, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP10, AP16}})),
GetCRL({{AP0, AP4}, {AP8, AP16}}));

EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP10}})), GetCRL({{APN2, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP2, AP10}})), GetCRL({{AP0, AP12}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{AP4, AP16}})), GetCRL({{AP0, AP16}}));
EXPECT_EQ(CRL.unionWith(GetCRL({{APN2, AP16}})), GetCRL({{APN2, AP16}}));
}

TEST_F(ConstantRangeListTest, Intersect) {
APInt APN2 = APInt(64, -2, /*isSigned=*/true);
APInt AP0 = APInt(64, 0, /*isSigned=*/true);
APInt AP2 = APInt(64, 2, /*isSigned=*/true);
APInt AP4 = APInt(64, 4, /*isSigned=*/true);
APInt AP6 = APInt(64, 6, /*isSigned=*/true);
APInt AP7 = APInt(64, 7, /*isSigned=*/true);
APInt AP8 = APInt(64, 8, /*isSigned=*/true);
APInt AP10 = APInt(64, 10, /*isSigned=*/true);
APInt AP11 = APInt(64, 11, /*isSigned=*/true);
APInt AP12 = APInt(64, 12, /*isSigned=*/true);
APInt AP16 = APInt(64, 16, /*isSigned=*/true);
ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});

// No intersection.
ConstantRangeList Empty;
EXPECT_EQ(CRL.intersectWith(Empty), Empty);
EXPECT_EQ(Empty.intersectWith(CRL), Empty);

EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP0}})), Empty);
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP8}})), Empty);
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP12, AP16}})), Empty);

// Single intersect range.
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}})), GetCRL({{AP0, AP2}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP6}})), GetCRL({{AP0, AP4}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP4}})), GetCRL({{AP2, AP4}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP6}})), GetCRL({{AP2, AP4}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP10}})), GetCRL({{AP8, AP10}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP6, AP16}})), GetCRL({{AP8, AP12}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP10, AP12}})), GetCRL({{AP10, AP12}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP10, AP16}})), GetCRL({{AP10, AP12}}));

// Multiple intersect ranges.
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP10}})),
GetCRL({{AP0, AP4}, {AP8, AP10}}));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this supposed to be 2-4,8-10?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the range to union is (-2, 10), APN2 not AP2 :-)
{(0, 4), (8,12)} U {(-2, 10)} = 0-4,8-10

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes I misread the APN2 to be AP2

EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP16}})), CRL);
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP10}})),
GetCRL({{AP2, AP4}, {AP8, AP10}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP16}})),
GetCRL({{AP2, AP4}, {AP8, AP12}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}, {AP6, AP10}})),
GetCRL({{AP0, AP2}, {AP8, AP10}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{AP2, AP6}, {AP10, AP16}})),
GetCRL({{AP2, AP4}, {AP10, AP12}}));
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP2}, {AP7, AP10}, {AP11, AP16}})),
GetCRL({{AP0, AP2}, {AP8, AP10}, {AP11, AP12}}));
EXPECT_EQ(CRL.intersectWith(CRL), CRL);
}

} // anonymous namespace
Loading