Skip to content

Commit ccf194b

Browse files
authored
[MLIR][Presburger] Implement convertVarKind for PresburgerRelation
1 parent 2f45b56 commit ccf194b

File tree

5 files changed

+88
-24
lines changed

5 files changed

+88
-24
lines changed

mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ class PresburgerRelation {
6666

6767
void insertVarInPlace(VarKind kind, unsigned pos, unsigned num = 1);
6868

69+
/// Converts variables of the specified kind in the column range [srcPos,
70+
/// srcPos + num) to variables of the specified kind at position dstPos. The
71+
/// ranges are relative to the kind of variable.
72+
///
73+
/// srcKind and dstKind must be different.
74+
void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num,
75+
VarKind dstKind, unsigned dstPos);
76+
6977
/// Return a reference to the list of disjuncts.
7078
ArrayRef<IntegerRelation> getAllDisjuncts() const;
7179

mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
1010
#include "mlir/Analysis/Presburger/IntegerRelation.h"
1111
#include "mlir/Analysis/Presburger/PWMAFunction.h"
12+
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
1213
#include "mlir/Analysis/Presburger/Simplex.h"
1314
#include "mlir/Analysis/Presburger/Utils.h"
1415
#include "llvm/ADT/STLExtras.h"
@@ -38,6 +39,23 @@ void PresburgerRelation::insertVarInPlace(VarKind kind, unsigned pos,
3839
space.insertVar(kind, pos, num);
3940
}
4041

42+
void PresburgerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
43+
unsigned num, VarKind dstKind,
44+
unsigned dstPos) {
45+
assert(srcKind != VarKind::Local && dstKind != VarKind::Local &&
46+
"srcKind/dstKind cannot be local");
47+
assert(srcKind != dstKind && "cannot convert variables to the same kind");
48+
assert(srcPos + num <= space.getNumVarKind(srcKind) &&
49+
"invalid range for source variables");
50+
assert(dstPos <= space.getNumVarKind(dstKind) &&
51+
"invalid position for destination variables");
52+
53+
space.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
54+
55+
for (IntegerRelation &disjunct : disjuncts)
56+
disjunct.convertVarKind(srcKind, srcPos, srcPos + num, dstKind, dstPos);
57+
}
58+
4159
unsigned PresburgerRelation::getNumDisjuncts() const {
4260
return disjuncts.size();
4361
}

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,6 @@
1616
using namespace mlir;
1717
using namespace presburger;
1818

19-
static IntegerRelation parseRelationFromSet(StringRef set, unsigned numDomain) {
20-
IntegerRelation rel = parseIntegerPolyhedron(set);
21-
22-
rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
23-
24-
return rel;
25-
}
26-
2719
TEST(IntegerRelationTest, getDomainAndRangeSet) {
2820
IntegerRelation rel = parseRelationFromSet(
2921
"(x, xr)[N] : (xr - x - 10 == 0, xr >= 0, N - xr >= 0)", 1);

mlir/unittests/Analysis/Presburger/Parser.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,27 @@ parsePWMAF(ArrayRef<std::pair<StringRef, StringRef>> pieces) {
8080
return func;
8181
}
8282

83+
inline IntegerRelation parseRelationFromSet(StringRef set, unsigned numDomain) {
84+
IntegerRelation rel = parseIntegerPolyhedron(set);
85+
86+
rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
87+
88+
return rel;
89+
}
90+
91+
inline PresburgerRelation
92+
parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
93+
unsigned numDomain) {
94+
assert(!strs.empty() && "strs should not be empty");
95+
96+
IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
97+
PresburgerRelation result(rel);
98+
for (unsigned i = 1, e = strs.size(); i < e; ++i)
99+
result.unionInPlace(parseIntegerPolyhedron(strs[i]));
100+
result.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain, 0);
101+
return result;
102+
}
103+
83104
} // namespace presburger
84105
} // namespace mlir
85106

mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
99
#include "Parser.h"
10+
#include "mlir/Analysis/Presburger/IntegerRelation.h"
1011
#include "mlir/Analysis/Presburger/Simplex.h"
1112

1213
#include <gmock/gmock.h>
@@ -16,22 +17,6 @@
1617
using namespace mlir;
1718
using namespace presburger;
1819

19-
static PresburgerRelation
20-
parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
21-
unsigned numDomain) {
22-
assert(!strs.empty() && "strs should not be empty");
23-
24-
IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
25-
rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
26-
PresburgerRelation result(rel);
27-
for (unsigned i = 1, e = strs.size(); i < e; ++i) {
28-
rel = parseIntegerPolyhedron(strs[i]);
29-
rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
30-
result.unionInPlace(rel);
31-
}
32-
return result;
33-
}
34-
3520
TEST(PresburgerRelationTest, intersectDomainAndRange) {
3621
{
3722
PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
@@ -291,3 +276,43 @@ TEST(PresburgerRelationTest, getDomainAndRangeSet) {
291276

292277
EXPECT_TRUE(rangeSet.isEqual(expectedRangeSet));
293278
}
279+
280+
TEST(PresburgerRelationTest, convertVarKind) {
281+
PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3, 0);
282+
283+
IntegerRelation disj1 = parseRelationFromSet(
284+
"(x, y, a)[U, V, W] : (x - U == 0, y + a - W == 0,"
285+
"U - V >= 0, y - a >= 0)",
286+
2),
287+
disj2 = parseRelationFromSet(
288+
"(x, y, a)[U, V, W] : (x + y - U == 0, x - a + V == 0,"
289+
"V - U >= 0, y + a >= 0)",
290+
2);
291+
292+
PresburgerRelation rel(disj1);
293+
rel.unionInPlace(disj2);
294+
295+
// Make a few kind conversions.
296+
rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
297+
rel.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 1);
298+
rel.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
299+
300+
// Expected rel.
301+
disj1.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
302+
disj1.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
303+
disj1.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
304+
disj2.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
305+
disj2.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
306+
disj2.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
307+
308+
PresburgerRelation expectedRel(disj1);
309+
expectedRel.unionInPlace(disj2);
310+
311+
// Check if var counts are correct.
312+
EXPECT_EQ(rel.getNumDomainVars(), 3u);
313+
EXPECT_EQ(rel.getNumRangeVars(), 3u);
314+
EXPECT_EQ(rel.getNumSymbolVars(), 0u);
315+
316+
// Check if identifiers are transferred correctly.
317+
EXPECT_TRUE(expectedRel.isEqual(rel));
318+
}

0 commit comments

Comments
 (0)