-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignSymbols #76397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-presburger Author: Bharathi Ramana Joshi (iambrj) ChangesFull diff: https://github.com/llvm/llvm-project/pull/76397.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 9fe2abafd36bad..6a450ddf3ed407 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -290,6 +290,11 @@ class PresburgerSpace {
/// the symbols in two spaces are aligned.
bool isAligned(const PresburgerSpace &other, VarKind kind) const;
+ /// Merge and align VarKind variables of `this` and `other` with respect to
+ /// identifiers. After this operation the VarKind variables of both spaces
+ /// have the same identifiers in the same order.
+ void mergeAndAlignVarKind(VarKind kind, PresburgerSpace &other);
+
void print(llvm::raw_ostream &os) const;
void dump() const;
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index cf1b3befbc89f8..3c440cebeee5f7 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -18,8 +18,9 @@ using namespace presburger;
bool Identifier::isEqual(const Identifier &other) const {
if (value == nullptr || other.value == nullptr)
return false;
- assert(value == other.value && idType == other.idType &&
- "Values of Identifiers are equal but their types do not match.");
+ assert(value != other.value ||
+ (value == other.value && idType == other.idType &&
+ "Values of Identifiers are equal but their types do not match."));
return value == other.value;
}
@@ -293,6 +294,39 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
// `identifiers` remains same.
}
+void PresburgerSpace::mergeAndAlignVarKind(VarKind kind,
+ PresburgerSpace &other) {
+ assert(usingIds && other.usingIds &&
+ "Both spaces need to have identifers to merge & align");
+
+ // First merge & align identifiers into `other` from `this`.
+ unsigned kindBeginOffset = other.getVarKindOffset(kind);
+ unsigned i = 0;
+ for (const Identifier *identifier =
+ identifiers.begin() + getVarKindOffset(kind);
+ identifier != identifiers.begin() + getVarKindEnd(kind); identifier++) {
+ // If the identifier exists in `other`, then align it; otherwise insert it
+ // assuming it is a new identifier. Search in `other` starting at position
+ // `i` since the left of `i` is aligned.
+ auto *findEnd = other.identifiers.begin() + other.getVarKindEnd(kind);
+ auto *itr = std::find(other.identifiers.begin() + kindBeginOffset + i,
+ findEnd, *identifier);
+ if (itr != findEnd) {
+ std::iter_swap(other.identifiers.begin() + kindBeginOffset + i, itr);
+ } else {
+ other.insertVar(kind, i);
+ other.getId(kind, i) = *identifier;
+ }
+ i++;
+ }
+
+ // Finally add identifiers that are in `other`, but not in `this` to `this`.
+ for (unsigned e = other.getNumVarKind(kind); i < e; i++) {
+ insertVar(kind, i);
+ getId(kind, i) = other.getId(kind, i);
+ }
+}
+
void PresburgerSpace::print(llvm::raw_ostream &os) const {
os << "Domain: " << getNumDomainVars() << ", "
<< "Range: " << getNumRangeVars() << ", "
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index dd06d462f54bee..b8a578620161a8 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -179,3 +179,80 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
}
+
+TEST(PresburgerSpaceTest, mergeSymbols) {
+ PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+ space.resetIds();
+
+ PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+ otherSpace.resetIds();
+
+ // Attach identifiers.
+ int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+ int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ // Note the common identifier
+ space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+ otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+ otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+ otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+ otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+ // Note the common identifier
+ otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+ otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+ otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+ space.mergeAndAlignVarKind(VarKind::Domain, otherSpace);
+ space.mergeAndAlignVarKind(VarKind::Range, otherSpace);
+ space.mergeAndAlignVarKind(VarKind::Symbol, otherSpace);
+
+ // Check if merge & align is successful
+ // Check domain var identifiers
+ EXPECT_EQ(5u, space.getNumRangeVars());
+ EXPECT_EQ(5u, otherSpace.getNumRangeVars());
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ space.getId(VarKind::Domain, 3) = Identifier(&otherIdentifiers[0]);
+ space.getId(VarKind::Domain, 4) = Identifier(&otherIdentifiers[1]);
+ otherSpace.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ otherSpace.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ otherSpace.getId(VarKind::Domain, 3) = Identifier(&otherIdentifiers[0]);
+ otherSpace.getId(VarKind::Domain, 4) = Identifier(&otherIdentifiers[1]);
+ // Check range var identifiers
+ EXPECT_EQ(5u, space.getNumRangeVars());
+ EXPECT_EQ(5u, otherSpace.getNumRangeVars());
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Range, 3) = Identifier(&otherIdentifiers[3]);
+ space.getId(VarKind::Range, 4) = Identifier(&otherIdentifiers[4]);
+ otherSpace.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ otherSpace.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+ otherSpace.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+ otherSpace.getId(VarKind::Range, 3) = Identifier(&otherIdentifiers[3]);
+ otherSpace.getId(VarKind::Range, 4) = Identifier(&otherIdentifiers[4]);
+ // Check symbol var identifiers
+ EXPECT_EQ(4u, space.getNumSymbolVars());
+ EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+ Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+ Identifier(&otherIdentifiers[7]));
+}
|
@@ -293,6 +294,39 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) { | |||
// `identifiers` remains same. | |||
} | |||
|
|||
void PresburgerSpace::mergeAndAlignVarKind(VarKind kind, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need non symbol merging anywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't seen any non-symbol merging in the affine dialect glue code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do mergeSymbols only then. I don't think we should have merging for any other variable kind.
18b89bc
to
293644e
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
You probably want to adjust the title of the PR as well. |
1107253
to
4cfa312
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
No description provided.