Skip to content

Commit d53abc4

Browse files
[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Add columns for constant values/dims
`ValueBoundsConstraintSet` maintains an internal constraint set (`IntegerRelation`), where every analyzed index-typed SSA value or dimension of a shaped type is represented with a dimension/symbol. Prior to this change, index-typed values with a statically known constant value and static shaped type dimensions were not added to the constraint set. Instead, `getExpr` directly returned an affine constrant expression. With this commit, dynamic and static values/dimension sizes are treated in the same way: in either case, a dimension/symbol is added to the constraint set. This is needed for a subsequent commit that adds support for branches.
1 parent cceedc9 commit d53abc4

File tree

2 files changed

+56
-17
lines changed

2 files changed

+56
-17
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,19 @@ class ValueBoundsConstraintSet
292292
/// value/dimension exists in the constraint set.
293293
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
294294

295+
/// Return an affine expression that represents column `pos` in the constraint
296+
/// set.
297+
AffineExpr getPosExpr(int64_t pos);
298+
295299
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
296300
/// "false", a dimension is added. The value/dimension is added to the
297-
/// worklist.
301+
/// worklist if `addToWorklist` is set.
298302
///
299303
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
300304
/// cannot be multiplied. Furthermore, bounds can only be queried for
301305
/// dimensions but not for symbols.
302-
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
306+
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
307+
bool addToWorklist = true);
303308

304309
/// Insert an anonymous column into the constraint set. The column is not
305310
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,25 +107,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
107107
assertValidValueDim(value, dim);
108108
#endif // NDEBUG
109109

110+
// Check if the value/dim is statically known. In that case, an affine
111+
// constant expression should be returned. This allows us to support
112+
// multiplications with constants. (Multiplications of two columns in the
113+
// constraint set is not supported.)
114+
std::optional<int64_t> constSize = std::nullopt;
110115
auto shapedType = dyn_cast<ShapedType>(value.getType());
111116
if (shapedType) {
112-
// Static dimension: return constant directly.
113117
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
114-
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
115-
} else {
116-
// Constant index value: return directly.
117-
if (auto constInt = ::getConstantIntValue(value))
118-
return builder.getAffineConstantExpr(*constInt);
118+
constSize = shapedType.getDimSize(*dim);
119+
} else if (auto constInt = ::getConstantIntValue(value)) {
120+
constSize = *constInt;
119121
}
120122

121-
// Dynamic value: add to constraint set.
123+
// If the value/dim is already mapped, return the corresponding expression
124+
// directly.
122125
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
123-
if (!valueDimToPosition.contains(valueDim))
124-
(void)insert(value, dim);
125-
int64_t pos = getPos(value, dim);
126-
return pos < cstr.getNumDimVars()
127-
? builder.getAffineDimExpr(pos)
128-
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
126+
if (valueDimToPosition.contains(valueDim)) {
127+
// If it is a constant, return an affine constant expression. Otherwise,
128+
// return an affine expression that represents the respective column in the
129+
// constraint set.
130+
if (constSize)
131+
return builder.getAffineConstantExpr(*constSize);
132+
return getPosExpr(getPos(value, dim));
133+
}
134+
135+
if (constSize) {
136+
// Constant index value/dim: add column to the constraint set, add EQ bound
137+
// and return an affine constant expression without pushing the newly added
138+
// column to the worklist.
139+
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
140+
if (shapedType)
141+
bound(value)[*dim] == *constSize;
142+
else
143+
bound(value) == *constSize;
144+
return builder.getAffineConstantExpr(*constSize);
145+
}
146+
147+
// Dynamic value/dim: insert column to the constraint set and put it on the
148+
// worklist. Return an affine expression that represents the newly inserted
149+
// column in the constraint set.
150+
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
129151
}
130152

131153
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -142,7 +164,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
142164

143165
int64_t ValueBoundsConstraintSet::insert(Value value,
144166
std::optional<int64_t> dim,
145-
bool isSymbol) {
167+
bool isSymbol, bool addToWorklist) {
146168
#ifndef NDEBUG
147169
assertValidValueDim(value, dim);
148170
#endif // NDEBUG
@@ -157,7 +179,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
157179
if (positionToValueDim[i].has_value())
158180
valueDimToPosition[*positionToValueDim[i]] = i;
159181

160-
worklist.push(pos);
182+
if (addToWorklist) {
183+
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
184+
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
185+
worklist.push(pos);
186+
}
187+
161188
return pos;
162189
}
163190

@@ -187,6 +214,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
187214
return it->second;
188215
}
189216

217+
AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
218+
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
219+
return pos < cstr.getNumDimVars()
220+
? builder.getAffineDimExpr(pos)
221+
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
222+
}
223+
190224
static Operation *getOwnerOfValue(Value value) {
191225
if (auto bbArg = dyn_cast<BlockArgument>(value))
192226
return bbArg.getOwner()->getParentOp();

0 commit comments

Comments
 (0)