@@ -187,14 +187,22 @@ DeletionKind memref::LoadOp::removeBlockingUses(
187
187
return DeletionKind::Delete;
188
188
}
189
189
190
- // / Returns the index of a memref in attribute form, given its indices.
190
+ // / Returns the index of a memref in attribute form, given its indices. Returns
191
+ // / a null pointer if whether the indices form a valid index for the provided
192
+ // / MemRefType cannot be computed. The indices must come from a valid memref
193
+ // / StoreOp or LoadOp.
191
194
static Attribute getAttributeIndexFromIndexOperands (MLIRContext *ctx,
192
- ValueRange indices) {
195
+ ValueRange indices,
196
+ MemRefType memrefType) {
193
197
SmallVector<Attribute> index;
194
- for (Value coord : indices) {
198
+ for (auto [ coord, dimSize] : llvm::zip ( indices, memrefType. getShape ()) ) {
195
199
IntegerAttr coordAttr;
196
200
if (!matchPattern (coord, m_Constant<IntegerAttr>(&coordAttr)))
197
201
return {};
202
+ // MemRefType shape dimensions are always positive (checked by verifier).
203
+ std::optional<uint64_t > coordInt = coordAttr.getValue ().tryZExtValue ();
204
+ if (!coordInt || coordInt.value () >= static_cast <uint64_t >(dimSize))
205
+ return {};
198
206
index.push_back (coordAttr);
199
207
}
200
208
return ArrayAttr::get (ctx, index);
@@ -205,8 +213,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
205
213
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
206
214
if (slot.ptr != getMemRef ())
207
215
return false ;
208
- Attribute index =
209
- getAttributeIndexFromIndexOperands ( getContext (), getIndices ());
216
+ Attribute index = getAttributeIndexFromIndexOperands (
217
+ getContext (), getIndices (), getMemRefType ());
210
218
if (!index)
211
219
return false ;
212
220
usedIndices.insert (index);
@@ -216,8 +224,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
216
224
DeletionKind memref::LoadOp::rewire (const DestructurableMemorySlot &slot,
217
225
DenseMap<Attribute, MemorySlot> &subslots,
218
226
RewriterBase &rewriter) {
219
- Attribute index =
220
- getAttributeIndexFromIndexOperands ( getContext (), getIndices ());
227
+ Attribute index = getAttributeIndexFromIndexOperands (
228
+ getContext (), getIndices (), getMemRefType ());
221
229
const MemorySlot &memorySlot = subslots.at (index);
222
230
rewriter.updateRootInPlace (*this , [&]() {
223
231
setMemRef (memorySlot.ptr );
@@ -258,8 +266,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
258
266
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
259
267
if (slot.ptr != getMemRef () || getValue () == slot.ptr )
260
268
return false ;
261
- Attribute index =
262
- getAttributeIndexFromIndexOperands ( getContext (), getIndices ());
269
+ Attribute index = getAttributeIndexFromIndexOperands (
270
+ getContext (), getIndices (), getMemRefType ());
263
271
if (!index || !slot.elementPtrs .contains (index))
264
272
return false ;
265
273
usedIndices.insert (index);
@@ -269,8 +277,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
269
277
DeletionKind memref::StoreOp::rewire (const DestructurableMemorySlot &slot,
270
278
DenseMap<Attribute, MemorySlot> &subslots,
271
279
RewriterBase &rewriter) {
272
- Attribute index =
273
- getAttributeIndexFromIndexOperands ( getContext (), getIndices ());
280
+ Attribute index = getAttributeIndexFromIndexOperands (
281
+ getContext (), getIndices (), getMemRefType ());
274
282
const MemorySlot &memorySlot = subslots.at (index);
275
283
rewriter.updateRootInPlace (*this , [&]() {
276
284
setMemRef (memorySlot.ptr );
0 commit comments