Skip to content

Commit fca8ef5

Browse files
[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Add columns for constant values/dims
`ValueBoundsConstraintSet` maintains an internal constraint set (`IntegerRelation`), where every analyzed index-typed SSA value or dimension of a shaped type is represented with a dimension/symbol. Prior to this change, index-typed values with a statically known constant value and static shaped type dimensions were not added to the constraint set. Instead, `getExpr` directly returned an affine constrant expression. With this commit, dynamic and static values/dimension sizes are treated in the same way: in either case, a dimension/symbol is added to the constraint set. This is needed for a subsequent commit that adds support for branches.
1 parent 5e4a443 commit fca8ef5

File tree

2 files changed

+56
-17
lines changed

2 files changed

+56
-17
lines changed

mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,19 @@ class ValueBoundsConstraintSet
295295
/// value/dimension exists in the constraint set.
296296
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
297297

298+
/// Return an affine expression that represents column `pos` in the constraint
299+
/// set.
300+
AffineExpr getPosExpr(int64_t pos);
301+
298302
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
299303
/// "false", a dimension is added. The value/dimension is added to the
300-
/// worklist.
304+
/// worklist if `addToWorklist` is set.
301305
///
302306
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
303307
/// cannot be multiplied. Furthermore, bounds can only be queried for
304308
/// dimensions but not for symbols.
305-
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
309+
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
310+
bool addToWorklist = true);
306311

307312
/// Insert an anonymous column into the constraint set. The column is not
308313
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension

mlir/lib/Interfaces/ValueBoundsOpInterface.cpp

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
110110
assertValidValueDim(value, dim);
111111
#endif // NDEBUG
112112

113+
// Check if the value/dim is statically known. In that case, an affine
114+
// constant expression should be returned. This allows us to support
115+
// multiplications with constants. (Multiplications of two columns in the
116+
// constraint set is not supported.)
117+
std::optional<int64_t> constSize = std::nullopt;
113118
auto shapedType = dyn_cast<ShapedType>(value.getType());
114119
if (shapedType) {
115-
// Static dimension: return constant directly.
116120
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
117-
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
118-
} else {
119-
// Constant index value: return directly.
120-
if (auto constInt = ::getConstantIntValue(value))
121-
return builder.getAffineConstantExpr(*constInt);
121+
constSize = shapedType.getDimSize(*dim);
122+
} else if (auto constInt = ::getConstantIntValue(value)) {
123+
constSize = *constInt;
122124
}
123125

124-
// Dynamic value: add to constraint set.
126+
// If the value/dim is already mapped, return the corresponding expression
127+
// directly.
125128
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
126-
if (!valueDimToPosition.contains(valueDim))
127-
(void)insert(value, dim);
128-
int64_t pos = getPos(value, dim);
129-
return pos < cstr.getNumDimVars()
130-
? builder.getAffineDimExpr(pos)
131-
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
129+
if (valueDimToPosition.contains(valueDim)) {
130+
// If it is a constant, return an affine constant expression. Otherwise,
131+
// return an affine expression that represents the respective column in the
132+
// constraint set.
133+
if (constSize)
134+
return builder.getAffineConstantExpr(*constSize);
135+
return getPosExpr(getPos(value, dim));
136+
}
137+
138+
if (constSize) {
139+
// Constant index value/dim: add column to the constraint set, add EQ bound
140+
// and return an affine constant expression without pushing the newly added
141+
// column to the worklist.
142+
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
143+
if (shapedType)
144+
bound(value)[*dim] == *constSize;
145+
else
146+
bound(value) == *constSize;
147+
return builder.getAffineConstantExpr(*constSize);
148+
}
149+
150+
// Dynamic value/dim: insert column to the constraint set and put it on the
151+
// worklist. Return an affine expression that represents the newly inserted
152+
// column in the constraint set.
153+
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
132154
}
133155

134156
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
@@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
145167

146168
int64_t ValueBoundsConstraintSet::insert(Value value,
147169
std::optional<int64_t> dim,
148-
bool isSymbol) {
170+
bool isSymbol, bool addToWorklist) {
149171
#ifndef NDEBUG
150172
assertValidValueDim(value, dim);
151173
#endif // NDEBUG
@@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
160182
if (positionToValueDim[i].has_value())
161183
valueDimToPosition[*positionToValueDim[i]] = i;
162184

163-
worklist.push(pos);
185+
if (addToWorklist) {
186+
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
187+
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
188+
worklist.push(pos);
189+
}
190+
164191
return pos;
165192
}
166193

@@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
190217
return it->second;
191218
}
192219

220+
AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
221+
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
222+
return pos < cstr.getNumDimVars()
223+
? builder.getAffineDimExpr(pos)
224+
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
225+
}
226+
193227
static Operation *getOwnerOfValue(Value value) {
194228
if (auto bbArg = dyn_cast<BlockArgument>(value))
195229
return bbArg.getOwner()->getParentOp();

0 commit comments

Comments
 (0)