@@ -107,25 +107,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
107
107
assertValidValueDim (value, dim);
108
108
#endif // NDEBUG
109
109
110
+ // Check if the value/dim is statically known. In that case, an affine
111
+ // constant expression should be returned. This allows us to support
112
+ // multiplications with constants. (Multiplications of two columns in the
113
+ // constraint set is not supported.)
114
+ std::optional<int64_t > constSize = std::nullopt;
110
115
auto shapedType = dyn_cast<ShapedType>(value.getType ());
111
116
if (shapedType) {
112
- // Static dimension: return constant directly.
113
117
if (shapedType.hasRank () && !shapedType.isDynamicDim (*dim))
114
- return builder.getAffineConstantExpr (shapedType.getDimSize (*dim));
115
- } else {
116
- // Constant index value: return directly.
117
- if (auto constInt = ::getConstantIntValue (value))
118
- return builder.getAffineConstantExpr (*constInt);
118
+ constSize = shapedType.getDimSize (*dim);
119
+ } else if (auto constInt = ::getConstantIntValue (value)) {
120
+ constSize = *constInt;
119
121
}
120
122
121
- // Dynamic value: add to constraint set.
123
+ // If the value/dim is already mapped, return the corresponding expression
124
+ // directly.
122
125
ValueDim valueDim = std::make_pair (value, dim.value_or (kIndexValue ));
123
- if (!valueDimToPosition.contains (valueDim))
124
- (void )insert (value, dim);
125
- int64_t pos = getPos (value, dim);
126
- return pos < cstr.getNumDimVars ()
127
- ? builder.getAffineDimExpr (pos)
128
- : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
126
+ if (valueDimToPosition.contains (valueDim)) {
127
+ // If it is a constant, return an affine constant expression. Otherwise,
128
+ // return an affine expression that represents the respective column in the
129
+ // constraint set.
130
+ if (constSize)
131
+ return builder.getAffineConstantExpr (*constSize);
132
+ return getPosExpr (getPos (value, dim));
133
+ }
134
+
135
+ if (constSize) {
136
+ // Constant index value/dim: add column to the constraint set, add EQ bound
137
+ // and return an affine constant expression without pushing the newly added
138
+ // column to the worklist.
139
+ (void )insert (value, dim, /* isSymbol=*/ true , /* addToWorklist=*/ false );
140
+ if (shapedType)
141
+ bound (value)[*dim] == *constSize;
142
+ else
143
+ bound (value) == *constSize;
144
+ return builder.getAffineConstantExpr (*constSize);
145
+ }
146
+
147
+ // Dynamic value/dim: insert column to the constraint set and put it on the
148
+ // worklist. Return an affine expression that represents the newly inserted
149
+ // column in the constraint set.
150
+ return getPosExpr (insert (value, dim, /* isSymbol=*/ true ));
129
151
}
130
152
131
153
AffineExpr ValueBoundsConstraintSet::getExpr (OpFoldResult ofr) {
@@ -142,7 +164,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
142
164
143
165
int64_t ValueBoundsConstraintSet::insert (Value value,
144
166
std::optional<int64_t > dim,
145
- bool isSymbol) {
167
+ bool isSymbol, bool addToWorklist ) {
146
168
#ifndef NDEBUG
147
169
assertValidValueDim (value, dim);
148
170
#endif // NDEBUG
@@ -157,7 +179,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
157
179
if (positionToValueDim[i].has_value ())
158
180
valueDimToPosition[*positionToValueDim[i]] = i;
159
181
160
- worklist.push (pos);
182
+ if (addToWorklist) {
183
+ LLVM_DEBUG (llvm::dbgs () << " Push to worklist: " << value
184
+ << " (dim: " << dim.value_or (kIndexValue ) << " )\n " );
185
+ worklist.push (pos);
186
+ }
187
+
161
188
return pos;
162
189
}
163
190
@@ -187,6 +214,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
187
214
return it->second ;
188
215
}
189
216
217
+ AffineExpr ValueBoundsConstraintSet::getPosExpr (int64_t pos) {
218
+ assert (pos >= 0 && pos < cstr.getNumDimAndSymbolVars () && " invalid position" );
219
+ return pos < cstr.getNumDimVars ()
220
+ ? builder.getAffineDimExpr (pos)
221
+ : builder.getAffineSymbolExpr (pos - cstr.getNumDimVars ());
222
+ }
223
+
190
224
static Operation *getOwnerOfValue (Value value) {
191
225
if (auto bbArg = dyn_cast<BlockArgument>(value))
192
226
return bbArg.getOwner ()->getParentOp ();
0 commit comments