Skip to content

Add ConstantRangeList::unionWith() and ::intersectWith() #96547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2024

Conversation

haopliu
Copy link
Contributor

@haopliu haopliu commented Jun 24, 2024

Add ConstantRangeList::unionWith() and ::intersectWith().

These methods will be used in the "initializes" attribute inference.
df11106

@haopliu haopliu requested review from jvoung and aeubanks June 24, 2024 20:10
@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2024

@llvm/pr-subscribers-llvm-ir

Author: Haopeng Liu (haopliu)

Changes

Add ConstantRangeList::unionWith() and ::intersectWith().

These methods will be used in the "initializes" attribute inference.
df11106


Full diff: https://github.com/llvm/llvm-project/pull/96547.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantRangeList.h (+8)
  • (modified) llvm/lib/IR/ConstantRangeList.cpp (+80)
  • (modified) llvm/unittests/IR/ConstantRangeListTest.cpp (+118)
diff --git a/llvm/include/llvm/IR/ConstantRangeList.h b/llvm/include/llvm/IR/ConstantRangeList.h
index f696bd6cc6a3d..3b32f92a76653 100644
--- a/llvm/include/llvm/IR/ConstantRangeList.h
+++ b/llvm/include/llvm/IR/ConstantRangeList.h
@@ -72,6 +72,14 @@ class [[nodiscard]] ConstantRangeList {
                          APInt(64, Upper, /*isSigned=*/true)));
   }
 
+  /// Return the range list that results from the union of this range
+  /// with another range list.
+  ConstantRangeList unionWith(const ConstantRangeList &CRL) const;
+
+  /// Return the range list that results from the intersection of this range
+  /// with another range list.
+  ConstantRangeList intersectWith(const ConstantRangeList &CRL) const;
+
   /// Return true if this range list is equal to another range list.
   bool operator==(const ConstantRangeList &CRL) const {
     return Ranges == CRL.Ranges;
diff --git a/llvm/lib/IR/ConstantRangeList.cpp b/llvm/lib/IR/ConstantRangeList.cpp
index 2cc483d4e4962..2d93d06ad30b1 100644
--- a/llvm/lib/IR/ConstantRangeList.cpp
+++ b/llvm/lib/IR/ConstantRangeList.cpp
@@ -81,6 +81,86 @@ void ConstantRangeList::insert(const ConstantRange &NewRange) {
   }
 }
 
+ConstantRangeList
+ConstantRangeList::unionWith(const ConstantRangeList &CRL) const {
+  assert(getBitWidth() == CRL.getBitWidth() &&
+         "ConstantRangeList types don't agree!");
+  // Handle common cases.
+  if (empty())
+    return CRL;
+  if (CRL.empty())
+    return *this;
+
+  ConstantRangeList Result;
+  size_t i = 0, j = 0;
+  ConstantRange PreviousRange(getBitWidth(), false);
+  if (Ranges[i].getLower().slt(CRL.Ranges[j].getLower())) {
+    PreviousRange = Ranges[i++];
+  } else {
+    PreviousRange = CRL.Ranges[j++];
+  }
+  auto UnionAndUpdateRange = [&PreviousRange,
+                              &Result](const ConstantRange &CR) {
+    assert(!CR.isSignWrappedSet() && "Upper wrapped ranges are not supported");
+    if (PreviousRange.getUpper().slt(CR.getLower())) {
+      Result.Ranges.push_back(PreviousRange);
+      PreviousRange = CR;
+    } else {
+      PreviousRange = ConstantRange(
+          PreviousRange.getLower(),
+          APIntOps::smax(PreviousRange.getUpper(), CR.getUpper()));
+    }
+  };
+  while (i < size() || j < CRL.size()) {
+    if (j == CRL.size() ||
+        (i < size() && Ranges[i].getLower().slt(CRL.Ranges[j].getLower()))) {
+      // Merge PreviousRange with this.
+      UnionAndUpdateRange(Ranges[i++]);
+    } else {
+      // Merge PreviousRange with CRL.
+      UnionAndUpdateRange(CRL.Ranges[j++]);
+    }
+  }
+  Result.Ranges.push_back(PreviousRange);
+  return Result;
+}
+
+ConstantRangeList
+ConstantRangeList::intersectWith(const ConstantRangeList &CRL) const {
+  assert(getBitWidth() == CRL.getBitWidth() &&
+         "ConstantRangeList types don't agree!");
+
+  // Handle common cases.
+  if (empty())
+    return *this;
+  if (CRL.empty())
+    return CRL;
+
+  ConstantRangeList Result;
+  size_t i = 0, j = 0;
+  while (i < size() && j < CRL.size()) {
+    auto &Range = this->Ranges[i];
+    auto &OtherRange = CRL.Ranges[j];
+    assert(!Range.isSignWrappedSet() && !OtherRange.isSignWrappedSet() &&
+           "Upper wrapped ranges are not supported");
+
+    APInt Start = Range.getLower().slt(OtherRange.getLower())
+                      ? OtherRange.getLower()
+                      : Range.getLower();
+    APInt End = Range.getUpper().slt(OtherRange.getUpper())
+                    ? Range.getUpper()
+                    : OtherRange.getUpper();
+    if (Start.slt(End))
+      Result.Ranges.push_back(ConstantRange(Start, End));
+
+    if (Range.getUpper().slt(OtherRange.getUpper()))
+      i++;
+    else
+      j++;
+  }
+  return Result;
+}
+
 void ConstantRangeList::print(raw_ostream &OS) const {
   interleaveComma(Ranges, OS, [&](ConstantRange CR) {
     OS << "(" << CR.getLower() << ", " << CR.getUpper() << ")";
diff --git a/llvm/unittests/IR/ConstantRangeListTest.cpp b/llvm/unittests/IR/ConstantRangeListTest.cpp
index 144b5ccdc1fc0..d2fee18c88b5a 100644
--- a/llvm/unittests/IR/ConstantRangeListTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeListTest.cpp
@@ -94,4 +94,122 @@ TEST_F(ConstantRangeListTest, Insert) {
   EXPECT_TRUE(CRL == Expected);
 }
 
+ConstantRangeList GetCRL(const SmallVector<std::pair<APInt, APInt>, 2> &Pairs) {
+  SmallVector<ConstantRange, 2> Ranges;
+  for (auto &[Start, End] : Pairs)
+    Ranges.push_back(ConstantRange(Start, End));
+  return ConstantRangeList(Ranges);
+}
+
+TEST_F(ConstantRangeListTest, Union) {
+  APInt AP0 = APInt(64, 0, /*isSigned=*/true);
+  APInt AP2 = APInt(64, 2, /*isSigned=*/true);
+  APInt AP4 = APInt(64, 4, /*isSigned=*/true);
+  APInt AP8 = APInt(64, 8, /*isSigned=*/true);
+  APInt AP10 = APInt(64, 10, /*isSigned=*/true);
+  APInt AP11 = APInt(64, 11, /*isSigned=*/true);
+  APInt AP12 = APInt(64, 12, /*isSigned=*/true);
+  ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});
+
+  // Union with a subset.
+  ConstantRangeList Empty;
+  EXPECT_TRUE(CRL.unionWith(Empty) == CRL);
+  EXPECT_TRUE(Empty.unionWith(CRL) == CRL);
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP0, AP2}})) == CRL);
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP10, AP12}})) == CRL);
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP0, AP2}, {AP8, AP10}})) == CRL);
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP0, AP2}, {AP10, AP12}})) == CRL);
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP2, AP4}, {AP8, AP10}})) == CRL);
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP2, AP4}, {AP10, AP12}})) == CRL);
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP0, AP4}, {AP8, AP10}, {AP11, AP12}})) ==
+              CRL);
+
+  EXPECT_TRUE(CRL.unionWith(CRL) == CRL);
+
+  // Union with new ranges.
+  APInt APN4 = APInt(64, -4, /*isSigned=*/true);
+  APInt APN2 = APInt(64, -2, /*isSigned=*/true);
+  APInt AP6 = APInt(64, 6, /*isSigned=*/true);
+  APInt AP7 = APInt(64, 7, /*isSigned=*/true);
+  APInt AP16 = APInt(64, 16, /*isSigned=*/true);
+  APInt AP18 = APInt(64, 18, /*isSigned=*/true);
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{APN4, APN2}})) ==
+              GetCRL({{APN4, APN2}, {AP0, AP4}, {AP8, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP6, AP7}})) ==
+              GetCRL({{AP0, AP4}, {AP6, AP7}, {AP8, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP16, AP18}})) ==
+              GetCRL({{AP0, AP4}, {AP8, AP12}, {AP16, AP18}}));
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{APN2, AP2}})) ==
+              GetCRL({{APN2, AP4}, {AP8, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP2, AP6}})) ==
+              GetCRL({{AP0, AP6}, {AP8, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP10, AP16}})) ==
+              GetCRL({{AP0, AP4}, {AP8, AP16}}));
+
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{APN2, AP10}})) == GetCRL({{APN2, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP2, AP10}})) == GetCRL({{AP0, AP12}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{AP4, AP16}})) == GetCRL({{AP0, AP16}}));
+  EXPECT_TRUE(CRL.unionWith(GetCRL({{APN2, AP16}})) == GetCRL({{APN2, AP16}}));
+}
+
+TEST_F(ConstantRangeListTest, Intersect) {
+  APInt APN2 = APInt(64, -2, /*isSigned=*/true);
+  APInt AP0 = APInt(64, 0, /*isSigned=*/true);
+  APInt AP2 = APInt(64, 2, /*isSigned=*/true);
+  APInt AP4 = APInt(64, 4, /*isSigned=*/true);
+  APInt AP6 = APInt(64, 6, /*isSigned=*/true);
+  APInt AP7 = APInt(64, 7, /*isSigned=*/true);
+  APInt AP8 = APInt(64, 8, /*isSigned=*/true);
+  APInt AP10 = APInt(64, 10, /*isSigned=*/true);
+  APInt AP11 = APInt(64, 11, /*isSigned=*/true);
+  APInt AP12 = APInt(64, 12, /*isSigned=*/true);
+  APInt AP16 = APInt(64, 16, /*isSigned=*/true);
+  ConstantRangeList CRL = GetCRL({{AP0, AP4}, {AP8, AP12}});
+
+  // No intersection.
+  ConstantRangeList Empty;
+  EXPECT_TRUE(CRL.intersectWith(Empty) == Empty);
+  EXPECT_TRUE(Empty.intersectWith(CRL) == Empty);
+
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{APN2, AP0}})) == Empty);
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP6, AP8}})) == Empty);
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP12, AP16}})) == Empty);
+
+  // Single intersect range.
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{APN2, AP2}})) == GetCRL({{AP0, AP2}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{APN2, AP6}})) == GetCRL({{AP0, AP4}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP2, AP4}})) == GetCRL({{AP2, AP4}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP2, AP6}})) == GetCRL({{AP2, AP4}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP6, AP10}})) ==
+              GetCRL({{AP8, AP10}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP6, AP16}})) ==
+              GetCRL({{AP8, AP12}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP10, AP12}})) ==
+              GetCRL({{AP10, AP12}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP10, AP16}})) ==
+              GetCRL({{AP10, AP12}}));
+
+  // Multiple intersect ranges.
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{APN2, AP10}})) ==
+              GetCRL({{AP0, AP4}, {AP8, AP10}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{APN2, AP16}})) == CRL);
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP2, AP10}})) ==
+              GetCRL({{AP2, AP4}, {AP8, AP10}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{AP2, AP16}})) ==
+              GetCRL({{AP2, AP4}, {AP8, AP12}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{{APN2, AP2}, {AP6, AP10}}})) ==
+              GetCRL({{AP0, AP2}, {AP8, AP10}}));
+  EXPECT_TRUE(CRL.intersectWith(GetCRL({{{AP2, AP6}, {AP10, AP16}}})) ==
+              GetCRL({{AP2, AP4}, {AP10, AP12}}));
+  EXPECT_TRUE(
+      CRL.intersectWith(GetCRL({{{APN2, AP2}, {AP7, AP10}, {AP11, AP16}}})) ==
+      GetCRL({{AP0, AP2}, {AP8, AP10}, {AP11, AP12}}));
+  EXPECT_TRUE(CRL.intersectWith(CRL) == CRL);
+}
+
 } // anonymous namespace

@@ -94,4 +94,122 @@ TEST_F(ConstantRangeListTest, Insert) {
EXPECT_TRUE(CRL == Expected);
}

ConstantRangeList GetCRL(const SmallVector<std::pair<APInt, APInt>, 2> &Pairs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does ArrayRef work here?

Copy link
Contributor Author

@haopliu haopliu Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it works! I thought ArrayRef requires specifying the explicit type at callsites.
Then, with ArrayRef, each callsite creates a temp std::vector argument by default?

}
auto UnionAndUpdateRange = [&PreviousRange,
&Result](const ConstantRange &CR) {
assert(!CR.isSignWrappedSet() && "Upper wrapped ranges are not supported");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these asserts should probably go into wherever we're adding a ConstantRange to a ConstantRangeList, rather than sprinkled throughout various other methods. then we can assume that this invariant holds for existing ConstantRangeLists and don't need extra asserts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Removed.

assert(!Range.isSignWrappedSet() && !OtherRange.isSignWrappedSet() &&
"Upper wrapped ranges are not supported");

APInt Start = Range.getLower().slt(OtherRange.getLower())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add some comments here (and also in union)? it seems like there are some details that would be good to have written down

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. PTAL.

APInt AP16 = APInt(64, 16, /*isSigned=*/true);
APInt AP18 = APInt(64, 18, /*isSigned=*/true);

EXPECT_TRUE(CRL.unionWith(GetCRL({{APN4, APN2}})) ==
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does EXPECT_EQ not work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

EXPECT_TRUE(CRL.unionWith(CRL) == CRL);

// Union with new ranges.
APInt APN4 = APInt(64, -4, /*isSigned=*/true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group this with APInts above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Copy link
Contributor

@jvoung jvoung left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, LGTM -- nice test cases!


// The intersection of two Ranges is (max(lowers), min(uppers)), and it's
// possible that max(lowers) > min(uppers). Add the intersection to result
// only if it's a non-upper-wrapped range.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you mean "empty" or did you actually mean "non-upper-wrapped"? I thought we didn't support upper-wrapped ranges at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revised to "non-empty".


// Multiple intersect ranges.
EXPECT_EQ(CRL.intersectWith(GetCRL({{APN2, AP10}})),
GetCRL({{AP0, AP4}, {AP8, AP10}}));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this supposed to be 2-4,8-10?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the range to union is (-2, 10), APN2 not AP2 :-)
{(0, 4), (8,12)} U {(-2, 10)} = 0-4,8-10

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes I misread the APN2 to be AP2

@haopliu haopliu merged commit e6c2216 into llvm:main Jun 25, 2024
3 of 4 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
Add ConstantRangeList::unionWith() and ::intersectWith().

These methods will be used in the "initializes" attribute inference.

llvm@df11106
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants