Skip to content

Commit cf055b4

Browse files
committed
[MLIR][Presburger] Use Identifiers outside Presburger library
1 parent 2ec01d5 commit cf055b4

File tree

7 files changed

+138
-147
lines changed

7 files changed

+138
-147
lines changed

mlir/include/mlir/Analysis/FlatLinearValueConstraints.h

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
205205
/// where each non-local variable can have an SSA Value attached to it.
206206
class FlatLinearValueConstraints : public FlatLinearConstraints {
207207
public:
208+
using Identifier = presburger::Identifier;
209+
208210
/// Constructs a constraint system reserving memory for the specified number
209211
/// of constraints and variables. `valArgs` are the optional SSA values
210212
/// associated with each dimension/symbol. These must either be empty or match
@@ -217,11 +219,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
217219
: FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
218220
numReservedCols, numDims, numSymbols, numLocals) {
219221
assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
220-
values.reserve(numReservedCols);
221-
if (valArgs.empty())
222-
values.resize(getNumDimAndSymbolVars(), std::nullopt);
223-
else
224-
values.append(valArgs.begin(), valArgs.end());
222+
// Use values in space for FlatLinearValueConstraints.
223+
space.resetIds();
224+
// Set the values for the non-local variables.
225+
for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
226+
if (valArgs[i])
227+
setValue(i, *valArgs[i]);
225228
}
226229

227230
/// Constructs a constraint system reserving memory for the specified number
@@ -236,11 +239,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
236239
: FlatLinearConstraints(numReservedInequalities, numReservedEqualities,
237240
numReservedCols, numDims, numSymbols, numLocals) {
238241
assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
239-
values.reserve(numReservedCols);
240-
if (valArgs.empty())
241-
values.resize(getNumDimAndSymbolVars(), std::nullopt);
242-
else
243-
values.append(valArgs.begin(), valArgs.end());
242+
// Use values in space for FlatLinearValueConstraints.
243+
space.resetIds();
244+
// Set the values for the non-local variables.
245+
for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
246+
if (valArgs[i])
247+
setValue(i, valArgs[i]);
244248
}
245249

246250
/// Constructs a constraint system with the specified number of dimensions
@@ -273,10 +277,12 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
273277
ArrayRef<std::optional<Value>> valArgs = {})
274278
: FlatLinearConstraints(fac) {
275279
assert(valArgs.empty() || valArgs.size() == getNumDimAndSymbolVars());
276-
if (valArgs.empty())
277-
values.resize(getNumDimAndSymbolVars(), std::nullopt);
278-
else
279-
values.append(valArgs.begin(), valArgs.end());
280+
// Use values in space for FlatLinearValueConstraints.
281+
space.resetIds();
282+
// Set the values for the non-local variables.
283+
for (unsigned i = 0, e = valArgs.size(); i < e; ++i)
284+
if (valArgs[i])
285+
setValue(i, *valArgs[i]);
280286
}
281287

282288
/// Creates an affine constraint system from an IntegerSet.
@@ -302,7 +308,9 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
302308
inline Value getValue(unsigned pos) const {
303309
assert(pos < getNumDimAndSymbolVars() && "Invalid position");
304310
assert(hasValue(pos) && "variable's Value not set");
305-
return *values[pos];
311+
VarKind kind = getVarKindAt(pos);
312+
unsigned relativePos = pos - getVarKindOffset(kind);
313+
return space.getId(kind, relativePos).getValue<Value>();
306314
}
307315

308316
/// Returns the Values associated with variables in range [start, end).
@@ -317,21 +325,44 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
317325
values->push_back(getValue(i));
318326
}
319327

320-
inline ArrayRef<std::optional<Value>> getMaybeValues() const {
321-
return {values.data(), values.size()};
328+
inline SmallVector<std::optional<Value>> getMaybeValues() const {
329+
SmallVector<std::optional<Value>> maybeValues;
330+
maybeValues.reserve(getNumDimAndSymbolVars());
331+
for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
332+
if (hasValue(i))
333+
maybeValues.push_back(getValue(i));
334+
else
335+
maybeValues.push_back(std::nullopt);
336+
}
337+
return maybeValues;
322338
}
323339

324-
inline ArrayRef<std::optional<Value>>
325-
getMaybeValues(presburger::VarKind kind) const {
326-
assert(kind != VarKind::Local &&
327-
"Local variables do not have any value attached to them.");
328-
return {values.data() + getVarKindOffset(kind), getNumVarKind(kind)};
329-
}
340+
inline SmallVector<std::optional<Value>>
341+
getMaybeValues(presburger::VarKind kind) const {
342+
assert(kind != VarKind::Local &&
343+
"Local variables do not have any value attached to them.");
344+
SmallVector<std::optional<Value>> maybeValues;
345+
maybeValues.reserve(getNumVarKind(kind));
346+
for (unsigned i = 0, e = getNumVarKind(kind); i < e; i++) {
347+
Identifier id = space.getId(kind, i);
348+
if (id.hasValue())
349+
maybeValues.push_back(space.getId(kind, i).getValue<Value>());
350+
else
351+
maybeValues.push_back(std::nullopt);
352+
}
353+
return maybeValues;
354+
}
330355

331356
/// Returns true if the pos^th variable has an associated Value.
332357
inline bool hasValue(unsigned pos) const {
333358
assert(pos < getNumDimAndSymbolVars() && "Invalid position");
334-
return values[pos].has_value();
359+
VarKind kind = getVarKindAt(pos);
360+
unsigned relativePos = pos - getVarKindOffset(kind);
361+
return space.getId(kind, relativePos).hasValue();
362+
}
363+
364+
void resetValues() {
365+
space.resetIds();
335366
}
336367

337368
unsigned appendDimVar(ValueRange vals);
@@ -360,7 +391,9 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
360391
/// Sets the Value associated with the pos^th variable.
361392
inline void setValue(unsigned pos, Value val) {
362393
assert(pos < getNumDimAndSymbolVars() && "invalid var position");
363-
values[pos] = val;
394+
VarKind kind = getVarKindAt(pos);
395+
unsigned relativePos = pos - getVarKindOffset(kind);
396+
space.getId(kind, relativePos) = presburger::Identifier(val);
364397
}
365398

366399
/// Sets the Values associated with the variables in the range [start, end).
@@ -455,17 +488,6 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
455488
// See implementation comments for more details.
456489
void fourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
457490
bool *isResultIntegerExact = nullptr) override;
458-
459-
/// Returns false if the fields corresponding to various variable counts, or
460-
/// equality/inequality buffer sizes aren't consistent; true otherwise. This
461-
/// is meant to be used within an assert internally.
462-
bool hasConsistentState() const override;
463-
464-
/// Values corresponding to the (column) non-local variables of this
465-
/// constraint system appearing in the order the variables correspond to
466-
/// columns. Variables that aren't associated with any Value are set to
467-
/// std::nullopt.
468-
SmallVector<std::optional<Value>, 8> values;
469491
};
470492

471493
/// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the

mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
1616
#define MLIR_DIALECT_AFFINE_ANALYSIS_AFFINEANALYSIS_H
1717

18+
#include "mlir/Analysis/Presburger/IntegerRelation.h"
1819
#include "mlir/Dialect/Arith/IR/Arith.h"
1920
#include "mlir/IR/Value.h"
2021
#include "llvm/ADT/SmallVector.h"
@@ -115,7 +116,7 @@ struct MemRefAccess {
115116
///
116117
/// Returns failure for yet unimplemented/unsupported cases (see docs of
117118
/// mlir::getIndexSet and mlir::getRelationFromMap for these cases).
118-
LogicalResult getAccessRelation(FlatAffineRelation &accessRel) const;
119+
LogicalResult getAccessRelation(presburger::IntegerRelation &accessRel) const;
119120

120121
/// Populates 'accessMap' with composition of AffineApplyOps reachable from
121122
/// 'indices'.

mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,9 @@ class FlatAffineRelation : public FlatAffineValueConstraints {
251251
/// For AffineValueMap, the domain and symbols have Value set corresponding to
252252
/// the Value in `map`. Returns failure if the AffineMap could not be flattened
253253
/// (i.e., semi-affine is not yet handled).
254-
LogicalResult getRelationFromMap(AffineMap &map, FlatAffineRelation &rel);
254+
LogicalResult getRelationFromMap(AffineMap &map, presburger::IntegerRelation &rel);
255255
LogicalResult getRelationFromMap(const AffineValueMap &map,
256-
FlatAffineRelation &rel);
256+
presburger::IntegerRelation &rel);
257257

258258
} // namespace affine
259259
} // namespace mlir

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 29 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -817,13 +817,12 @@ FlatLinearValueConstraints::FlatLinearValueConstraints(IntegerSet set,
817817
set.getNumDims() + set.getNumSymbols() + 1,
818818
set.getNumDims(), set.getNumSymbols(),
819819
/*numLocals=*/0) {
820-
// Populate values.
821-
if (operands.empty()) {
822-
values.resize(getNumDimAndSymbolVars(), std::nullopt);
823-
} else {
824-
assert(set.getNumInputs() == operands.size() && "operand count mismatch");
825-
values.assign(operands.begin(), operands.end());
826-
}
820+
// Use values in space for FlatLinearValueConstraints.
821+
space.resetIds();
822+
// Set the values for the non-local variables.
823+
for (unsigned i = 0, e = operands.size(); i < e; ++i)
824+
setValue(i, operands[i]);
825+
827826

828827
// Flatten expressions and add them to the constraint system.
829828
std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -873,11 +872,6 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
873872
unsigned num) {
874873
unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
875874

876-
if (kind != VarKind::Local) {
877-
values.insert(values.begin() + absolutePos, num, std::nullopt);
878-
assert(values.size() == getNumDimAndSymbolVars());
879-
}
880-
881875
return absolutePos;
882876
}
883877

@@ -890,22 +884,17 @@ unsigned FlatLinearValueConstraints::insertVar(VarKind kind, unsigned pos,
890884
unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
891885

892886
// If a Value is provided, insert it; otherwise use std::nullopt.
893-
for (unsigned i = 0; i < num; ++i)
894-
values.insert(values.begin() + absolutePos + i,
895-
vals[i] ? std::optional<Value>(vals[i]) : std::nullopt);
887+
for (unsigned i = 0, e = vals.size(); i < e; ++i)
888+
setValue(absolutePos + i, vals[i]);
896889

897-
assert(values.size() == getNumDimAndSymbolVars());
898890
return absolutePos;
899891
}
900892

901893
/// Checks if two constraint systems are in the same space, i.e., if they are
902894
/// associated with the same set of variables, appearing in the same order.
903895
static bool areVarsAligned(const FlatLinearValueConstraints &a,
904896
const FlatLinearValueConstraints &b) {
905-
return a.getNumDimVars() == b.getNumDimVars() &&
906-
a.getNumSymbolVars() == b.getNumSymbolVars() &&
907-
a.getNumVars() == b.getNumVars() &&
908-
a.getMaybeValues().equals(b.getMaybeValues());
897+
return a.getSpace().isAligned(b.getSpace());
909898
}
910899

911900
/// Calls areVarsAligned to check if two constraint systems have the same set
@@ -928,12 +917,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
928917
return true;
929918

930919
SmallPtrSet<Value, 8> uniqueVars;
931-
ArrayRef<std::optional<Value>> maybeValues =
932-
cst.getMaybeValues().slice(start, end - start);
933-
for (std::optional<Value> val : maybeValues) {
920+
SmallVector<std::optional<Value>, 8> maybeValuesAll = cst.getMaybeValues();
921+
ArrayRef<std::optional<Value>> maybeValues = {maybeValuesAll.data() + start,
922+
maybeValuesAll.data() + end};
923+
924+
for (std::optional<Value> val : maybeValues)
934925
if (val && !uniqueVars.insert(*val).second)
935926
return false;
936-
}
927+
937928
return true;
938929
}
939930

@@ -1058,20 +1049,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
10581049
"expected same number of symbols");
10591050
}
10601051

1061-
bool FlatLinearValueConstraints::hasConsistentState() const {
1062-
return IntegerPolyhedron::hasConsistentState() &&
1063-
values.size() == getNumDimAndSymbolVars();
1064-
}
1065-
10661052
void FlatLinearValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
10671053
unsigned varLimit) {
10681054
IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
1069-
unsigned offset = getVarKindOffset(kind);
1070-
1071-
if (kind != VarKind::Local) {
1072-
values.erase(values.begin() + varStart + offset,
1073-
values.begin() + varLimit + offset);
1074-
}
10751055
}
10761056

10771057
AffineMap
@@ -1089,14 +1069,15 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
10891069

10901070
dims.reserve(getNumDimVars());
10911071
syms.reserve(getNumSymbolVars());
1092-
for (unsigned i = getVarKindOffset(VarKind::SetDim),
1093-
e = getVarKindEnd(VarKind::SetDim);
1094-
i < e; ++i)
1095-
dims.push_back(values[i] ? *values[i] : Value());
1096-
for (unsigned i = getVarKindOffset(VarKind::Symbol),
1097-
e = getVarKindEnd(VarKind::Symbol);
1098-
i < e; ++i)
1099-
syms.push_back(values[i] ? *values[i] : Value());
1072+
for (unsigned i = 0, e = getNumVarKind(VarKind::SetDim); i < e; ++i) {
1073+
Identifier id = space.getId(VarKind::SetDim, i);
1074+
dims.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
1075+
}
1076+
for (unsigned i = 0, e = getNumVarKind(VarKind::Symbol); i < e; ++i) {
1077+
Identifier id = space.getId(VarKind::Symbol, i);
1078+
syms.push_back(id.hasValue() ? Value(id.getValue<Value>()) : Value());
1079+
}
1080+
11001081

11011082
AffineMap alignedMap =
11021083
alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
@@ -1110,8 +1091,7 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
11101091
bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
11111092
unsigned offset) const {
11121093
unsigned i = offset;
1113-
for (const auto &mayBeVar :
1114-
ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
1094+
for (const auto &mayBeVar : getMaybeValues()) {
11151095
if (mayBeVar && *mayBeVar == val) {
11161096
*pos = i;
11171097
return true;
@@ -1122,25 +1102,12 @@ bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
11221102
}
11231103

11241104
bool FlatLinearValueConstraints::containsVar(Value val) const {
1125-
return llvm::any_of(values, [&](const std::optional<Value> &mayBeVar) {
1126-
return mayBeVar && *mayBeVar == val;
1127-
});
1105+
unsigned pos;
1106+
return findVar(val, &pos, 0);
11281107
}
11291108

11301109
void FlatLinearValueConstraints::swapVar(unsigned posA, unsigned posB) {
11311110
IntegerPolyhedron::swapVar(posA, posB);
1132-
1133-
if (getVarKindAt(posA) == VarKind::Local &&
1134-
getVarKindAt(posB) == VarKind::Local)
1135-
return;
1136-
1137-
// Treat value of a local variable as std::nullopt.
1138-
if (getVarKindAt(posA) == VarKind::Local)
1139-
values[posB] = std::nullopt;
1140-
else if (getVarKindAt(posB) == VarKind::Local)
1141-
values[posA] = std::nullopt;
1142-
else
1143-
std::swap(values[posA], values[posB]);
11441111
}
11451112

11461113
void FlatLinearValueConstraints::addBound(BoundType type, Value val,
@@ -1182,27 +1149,13 @@ void FlatLinearValueConstraints::printSpace(raw_ostream &os) const {
11821149

11831150
void FlatLinearValueConstraints::clearAndCopyFrom(
11841151
const IntegerRelation &other) {
1185-
1186-
if (auto *otherValueSet =
1187-
dyn_cast<const FlatLinearValueConstraints>(&other)) {
1188-
*this = *otherValueSet;
1189-
} else {
1190-
*static_cast<IntegerRelation *>(this) = other;
1191-
values.clear();
1192-
values.resize(getNumDimAndSymbolVars(), std::nullopt);
1193-
}
1152+
IntegerPolyhedron::clearAndCopyFrom(other);
11941153
}
11951154

11961155
void FlatLinearValueConstraints::fourierMotzkinEliminate(
11971156
unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
1198-
SmallVector<std::optional<Value>, 8> newVals = values;
1199-
if (getVarKindAt(pos) != VarKind::Local)
1200-
newVals.erase(newVals.begin() + pos);
1201-
// Note: Base implementation discards all associated Values.
12021157
IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
12031158
isResultIntegerExact);
1204-
values = newVals;
1205-
assert(values.size() == getNumDimAndSymbolVars());
12061159
}
12071160

12081161
void FlatLinearValueConstraints::projectOut(Value val) {
@@ -1215,11 +1168,7 @@ void FlatLinearValueConstraints::projectOut(Value val) {
12151168

12161169
LogicalResult FlatLinearValueConstraints::unionBoundingBox(
12171170
const FlatLinearValueConstraints &otherCst) {
1218-
assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1219-
assert(otherCst.getMaybeValues()
1220-
.slice(0, getNumDimVars())
1221-
.equals(getMaybeValues().slice(0, getNumDimVars())) &&
1222-
"dim values mismatch");
1171+
assert(otherCst.getSpace().isAligned(getSpace(), VarKind::SetDim) && "dims mismatch");
12231172
assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
12241173
assert(getNumLocalVars() == 0 && "local vars not supported yet here");
12251174

0 commit comments

Comments
 (0)