Skip to content

[MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignSymbols #76397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 1, 2024

Conversation

iambrj
Copy link
Member

@iambrj iambrj commented Dec 26, 2023

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Dec 26, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-presburger

Author: Bharathi Ramana Joshi (iambrj)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/76397.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h (+5)
  • (modified) mlir/lib/Analysis/Presburger/PresburgerSpace.cpp (+36-2)
  • (modified) mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp (+77)
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 9fe2abafd36bad..6a450ddf3ed407 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -290,6 +290,11 @@ class PresburgerSpace {
   /// the symbols in two spaces are aligned.
   bool isAligned(const PresburgerSpace &other, VarKind kind) const;
 
+  /// Merge and align VarKind variables of `this` and `other` with respect to
+  /// identifiers. After this operation the VarKind variables of both spaces
+  /// have the same identifiers in the same order.
+  void mergeAndAlignVarKind(VarKind kind, PresburgerSpace &other);
+
   void print(llvm::raw_ostream &os) const;
   void dump() const;
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index cf1b3befbc89f8..3c440cebeee5f7 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -18,8 +18,9 @@ using namespace presburger;
 bool Identifier::isEqual(const Identifier &other) const {
   if (value == nullptr || other.value == nullptr)
     return false;
-  assert(value == other.value && idType == other.idType &&
-         "Values of Identifiers are equal but their types do not match.");
+  assert(value != other.value ||
+         (value == other.value && idType == other.idType &&
+          "Values of Identifiers are equal but their types do not match."));
   return value == other.value;
 }
 
@@ -293,6 +294,39 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
   // `identifiers` remains same.
 }
 
+void PresburgerSpace::mergeAndAlignVarKind(VarKind kind,
+                                           PresburgerSpace &other) {
+  assert(usingIds && other.usingIds &&
+         "Both spaces need to have identifers to merge & align");
+
+  // First merge & align identifiers into `other` from `this`.
+  unsigned kindBeginOffset = other.getVarKindOffset(kind);
+  unsigned i = 0;
+  for (const Identifier *identifier =
+           identifiers.begin() + getVarKindOffset(kind);
+       identifier != identifiers.begin() + getVarKindEnd(kind); identifier++) {
+    // If the identifier exists in `other`, then align it; otherwise insert it
+    // assuming it is a new identifier. Search in `other` starting at position
+    // `i` since the left of `i` is aligned.
+    auto *findEnd = other.identifiers.begin() + other.getVarKindEnd(kind);
+    auto *itr = std::find(other.identifiers.begin() + kindBeginOffset + i,
+                          findEnd, *identifier);
+    if (itr != findEnd) {
+      std::iter_swap(other.identifiers.begin() + kindBeginOffset + i, itr);
+    } else {
+      other.insertVar(kind, i);
+      other.getId(kind, i) = *identifier;
+    }
+    i++;
+  }
+
+  // Finally add identifiers that are in `other`, but not in `this` to `this`.
+  for (unsigned e = other.getNumVarKind(kind); i < e; i++) {
+    insertVar(kind, i);
+    getId(kind, i) = other.getId(kind, i);
+  }
+}
+
 void PresburgerSpace::print(llvm::raw_ostream &os) const {
   os << "Domain: " << getNumDomainVars() << ", "
      << "Range: " << getNumRangeVars() << ", "
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index dd06d462f54bee..b8a578620161a8 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -179,3 +179,80 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
 }
+
+TEST(PresburgerSpaceTest, mergeSymbols) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note the common identifier
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note the common identifier
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  space.mergeAndAlignVarKind(VarKind::Domain, otherSpace);
+  space.mergeAndAlignVarKind(VarKind::Range, otherSpace);
+  space.mergeAndAlignVarKind(VarKind::Symbol, otherSpace);
+
+  // Check if merge & align is successful
+  // Check domain var identifiers
+  EXPECT_EQ(5u, space.getNumRangeVars());
+  EXPECT_EQ(5u, otherSpace.getNumRangeVars());
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Domain, 3) = Identifier(&otherIdentifiers[0]);
+  space.getId(VarKind::Domain, 4) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Domain, 3) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 4) = Identifier(&otherIdentifiers[1]);
+  // Check range var identifiers
+  EXPECT_EQ(5u, space.getNumRangeVars());
+  EXPECT_EQ(5u, otherSpace.getNumRangeVars());
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Range, 3) = Identifier(&otherIdentifiers[3]);
+  space.getId(VarKind::Range, 4) = Identifier(&otherIdentifiers[4]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  otherSpace.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  otherSpace.getId(VarKind::Range, 3) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 4) = Identifier(&otherIdentifiers[4]);
+  // Check symbol var identifiers
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+}

@@ -293,6 +294,39 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
// `identifiers` remains same.
}

void PresburgerSpace::mergeAndAlignVarKind(VarKind kind,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need non symbol merging anywhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't seen any non-symbol merging in the affine dialect glue code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do mergeSymbols only then. I don't think we should have merging for any other variable kind.

Copy link

github-actions bot commented Dec 31, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@tobiasgrosser
Copy link
Contributor

You probably want to adjust the title of the PR as well.

@iambrj iambrj changed the title [MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignVarKind [MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignSymbols Dec 31, 2023
@iambrj iambrj force-pushed the spaceMergeSymbols branch 2 times, most recently from 1107253 to 4cfa312 Compare December 31, 2023 07:40
@iambrj iambrj requested a review from Groverkss December 31, 2023 17:57
Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@iambrj iambrj merged commit ff80414 into llvm:main Jan 1, 2024
@iambrj iambrj deleted the spaceMergeSymbols branch January 1, 2024 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants