@@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
110
110
assertValidValueDim (value, dim);
111
111
#endif // NDEBUG
112
112
113
+ // Check if the value/dim is statically known. In that case, an affine
114
+ // constant expression should be returned. This allows us to support
115
+ // multiplications with constants. (Multiplications of two columns in the
116
+ // constraint set is not supported.)
117
+ std::optional<int64_t > constSize = std::nullopt;
113
118
auto shapedType = dyn_cast<ShapedType>(value.getType ());
114
119
if (shapedType) {
115
- // Static dimension: return constant directly.
116
120
if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
117
- return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
118
- } else {
119
- // Constant index value: return directly.
120
- if (auto constInt = ::getConstantIntValue (value))
121
- return builder.getAffineConstantExpr (*constInt);
121
+ constSize = shapedType.getDimSize (*dim);
122
+ } else if (auto constInt = ::getConstantIntValue (value)) {
123
+ constSize = *constInt;
122
124
}
123
125
124
- // Dynamic value: add to constraint set.
126
+ // If the value/dim is already mapped, return the corresponding expression
127
+ // directly.
125
128
ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
126
- if (!valueDimToPosition.contains (valueDim))
127
- (void )insert (value, dim);
128
- int64_t pos = getPos (value, dim);
129
- return pos < cstr.getNumDimVars ()
130
- ? builder.getAffineDimExpr (pos)
131
- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
129
+ if (valueDimToPosition.contains (valueDim)) {
130
+ // If it is a constant, return an affine constant expression. Otherwise,
131
+ // return an affine expression that represents the respective column in the
132
+ // constraint set.
133
+ if (constSize)
134
+ return builder.getAffineConstantExpr (*constSize);
135
+ return getPosExpr (getPos (value, dim));
136
+ }
137
+
138
+ if (constSize) {
139
+ // Constant index value/dim: add column to the constraint set, add EQ bound
140
+ // and return an affine constant expression without pushing the newly added
141
+ // column to the worklist.
142
+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
143
+ if (shapedType)
144
+ bound (value)[*dim] == *constSize;
145
+ else
146
+ bound (value) == *constSize;
147
+ return builder.getAffineConstantExpr (*constSize);
148
+ }
149
+
150
+ // Dynamic value/dim: insert column to the constraint set and put it on the
151
+ // worklist. Return an affine expression that represents the newly inserted
152
+ // column in the constraint set.
153
+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
132
154
}
133
155
134
156
AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
145
167
146
168
int64_t ValueBoundsConstraintSet::insert (Value value,
147
169
std::optional<int64_t > dim,
148
- bool isSymbol) {
170
+ bool isSymbol, bool addToWorklist ) {
149
171
#ifndef NDEBUG
150
172
assertValidValueDim (value, dim);
151
173
#endif // NDEBUG
@@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
160
182
if (positionToValueDim[i].has_value ())
161
183
valueDimToPosition[*positionToValueDim[i]] = i;
162
184
163
- worklist.push (pos);
185
+ if (addToWorklist) {
186
+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
187
+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
188
+ worklist.push (pos);
189
+ }
190
+
164
191
return pos;
165
192
}
166
193
@@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
190
217
return it->second ;
191
218
}
192
219
220
+ AffineExpr ValueBoundsConstraintSet::getPosExpr (int64_t pos) {
221
+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () && " invalid position" );
222
+ return pos < cstr.getNumDimVars ()
223
+ ? builder.getAffineDimExpr (pos)
224
+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
225
+ }
226
+
193
227
static Operation *getOwnerOfValue (Value value) {
194
228
if (auto bbArg = dyn_cast<BlockArgument>(value))
195
229
return bbArg.getOwner ()->getParentOp ();
0 commit comments