Skip to content

Commit ff80414

Browse files
authored
[MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignSymbols (#76397)
1 parent 945c2e6 commit ff80414

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,11 @@ class PresburgerSpace {
290290
/// the symbols in two spaces are aligned.
291291
bool isAligned(const PresburgerSpace &other, VarKind kind) const;
292292

293+
/// Merge and align symbol variables of `this` and `other` with respect to
294+
/// identifiers. After this operation the symbol variables of both spaces have
295+
/// the same identifiers in the same order.
296+
void mergeAndAlignSymbols(PresburgerSpace &other);
297+
293298
void print(llvm::raw_ostream &os) const;
294299
void dump() const;
295300

mlir/lib/Analysis/Presburger/PresburgerSpace.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,40 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
294294
// `identifiers` remains same.
295295
}
296296

297+
void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
298+
assert(usingIds && other.usingIds &&
299+
"Both spaces need to have identifers to merge & align");
300+
301+
// First merge & align identifiers into `other` from `this`.
302+
unsigned kindBeginOffset = other.getVarKindOffset(VarKind::Symbol);
303+
unsigned i = 0;
304+
for (const Identifier *identifier =
305+
identifiers.begin() + getVarKindOffset(VarKind::Symbol);
306+
identifier != identifiers.begin() + getVarKindEnd(VarKind::Symbol);
307+
identifier++) {
308+
// If the identifier exists in `other`, then align it; otherwise insert it
309+
// assuming it is a new identifier. Search in `other` starting at position
310+
// `i` since the left of `i` is aligned.
311+
auto *findEnd =
312+
other.identifiers.begin() + other.getVarKindEnd(VarKind::Symbol);
313+
auto *itr = std::find(other.identifiers.begin() + kindBeginOffset + i,
314+
findEnd, *identifier);
315+
if (itr != findEnd) {
316+
std::iter_swap(other.identifiers.begin() + kindBeginOffset + i, itr);
317+
} else {
318+
other.insertVar(VarKind::Symbol, i);
319+
other.getId(VarKind::Symbol, i) = *identifier;
320+
}
321+
i++;
322+
}
323+
324+
// Finally add identifiers that are in `other`, but not in `this` to `this`.
325+
for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; i++) {
326+
insertVar(VarKind::Symbol, i);
327+
getId(VarKind::Symbol, i) = other.getId(VarKind::Symbol, i);
328+
}
329+
}
330+
297331
void PresburgerSpace::print(llvm::raw_ostream &os) const {
298332
os << "Domain: " << getNumDomainVars() << ", "
299333
<< "Range: " << getNumRangeVars() << ", "

mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,73 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
193193
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
194194
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
195195
}
196+
197+
TEST(PresburgerSpaceTest, mergeAndAlignSymbols) {
198+
PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
199+
space.resetIds();
200+
201+
PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
202+
otherSpace.resetIds();
203+
204+
// Attach identifiers.
205+
int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
206+
int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
207+
208+
space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
209+
space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
210+
// Note the common identifier.
211+
space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
212+
space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
213+
space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
214+
space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
215+
space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
216+
space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
217+
218+
otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
219+
otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
220+
otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
221+
otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
222+
otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
223+
// Note the common identifier.
224+
otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
225+
otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
226+
otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
227+
228+
space.mergeAndAlignSymbols(otherSpace);
229+
230+
// Check if merge & align is successful.
231+
// Check symbol var identifiers.
232+
EXPECT_EQ(4u, space.getNumSymbolVars());
233+
EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
234+
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
235+
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
236+
EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
237+
EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
238+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
239+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
240+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
241+
Identifier(&otherIdentifiers[5]));
242+
EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
243+
Identifier(&otherIdentifiers[7]));
244+
// Check that domain and range var identifiers are not affected.
245+
EXPECT_EQ(3u, space.getNumDomainVars());
246+
EXPECT_EQ(3u, space.getNumRangeVars());
247+
EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
248+
EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
249+
EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
250+
EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
251+
EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
252+
EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
253+
EXPECT_EQ(3u, otherSpace.getNumDomainVars());
254+
EXPECT_EQ(2u, otherSpace.getNumRangeVars());
255+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
256+
Identifier(&otherIdentifiers[0]));
257+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
258+
Identifier(&otherIdentifiers[1]));
259+
EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
260+
Identifier(&otherIdentifiers[2]));
261+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
262+
Identifier(&otherIdentifiers[3]));
263+
EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
264+
Identifier(&otherIdentifiers[4]));
265+
}

0 commit comments

Comments
 (0)