Skip to content

Commit b142501

Browse files
authored
[mlir][memref] Fix segfault in SROA (#71063)
Fixes #70902. The out of bounds check in the SROA implementation for MemRef was not actually testing anything because it only operated on a store op which does not trigger the logic by itself. It is now checked for real and the underlying bug is fixed. I checked the LLVM implementation just in case but this should not happen as out-of-bound checks happen in GEP's verifier there.
1 parent ba13978 commit b142501

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,22 @@ DeletionKind memref::LoadOp::removeBlockingUses(
187187
return DeletionKind::Delete;
188188
}
189189

190-
/// Returns the index of a memref in attribute form, given its indices.
190+
/// Returns the index of a memref in attribute form, given its indices. Returns
191+
/// a null pointer if whether the indices form a valid index for the provided
192+
/// MemRefType cannot be computed. The indices must come from a valid memref
193+
/// StoreOp or LoadOp.
191194
static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
192-
ValueRange indices) {
195+
ValueRange indices,
196+
MemRefType memrefType) {
193197
SmallVector<Attribute> index;
194-
for (Value coord : indices) {
198+
for (auto [coord, dimSize] : llvm::zip(indices, memrefType.getShape())) {
195199
IntegerAttr coordAttr;
196200
if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
197201
return {};
202+
// MemRefType shape dimensions are always positive (checked by verifier).
203+
std::optional<uint64_t> coordInt = coordAttr.getValue().tryZExtValue();
204+
if (!coordInt || coordInt.value() >= static_cast<uint64_t>(dimSize))
205+
return {};
198206
index.push_back(coordAttr);
199207
}
200208
return ArrayAttr::get(ctx, index);
@@ -205,8 +213,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
205213
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
206214
if (slot.ptr != getMemRef())
207215
return false;
208-
Attribute index =
209-
getAttributeIndexFromIndexOperands(getContext(), getIndices());
216+
Attribute index = getAttributeIndexFromIndexOperands(
217+
getContext(), getIndices(), getMemRefType());
210218
if (!index)
211219
return false;
212220
usedIndices.insert(index);
@@ -216,8 +224,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
216224
DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
217225
DenseMap<Attribute, MemorySlot> &subslots,
218226
RewriterBase &rewriter) {
219-
Attribute index =
220-
getAttributeIndexFromIndexOperands(getContext(), getIndices());
227+
Attribute index = getAttributeIndexFromIndexOperands(
228+
getContext(), getIndices(), getMemRefType());
221229
const MemorySlot &memorySlot = subslots.at(index);
222230
rewriter.updateRootInPlace(*this, [&]() {
223231
setMemRef(memorySlot.ptr);
@@ -258,8 +266,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
258266
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
259267
if (slot.ptr != getMemRef() || getValue() == slot.ptr)
260268
return false;
261-
Attribute index =
262-
getAttributeIndexFromIndexOperands(getContext(), getIndices());
269+
Attribute index = getAttributeIndexFromIndexOperands(
270+
getContext(), getIndices(), getMemRefType());
263271
if (!index || !slot.elementPtrs.contains(index))
264272
return false;
265273
usedIndices.insert(index);
@@ -269,8 +277,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
269277
DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
270278
DenseMap<Attribute, MemorySlot> &subslots,
271279
RewriterBase &rewriter) {
272-
Attribute index =
273-
getAttributeIndexFromIndexOperands(getContext(), getIndices());
280+
Attribute index = getAttributeIndexFromIndexOperands(
281+
getContext(), getIndices(), getMemRefType());
274282
const MemorySlot &memorySlot = subslots.at(index);
275283
rewriter.updateRootInPlace(*this, [&]() {
276284
setMemRef(memorySlot.ptr);

mlir/test/Dialect/MemRef/sroa.mlir

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
132132

133133
// -----
134134

135-
// CHECK-LABEL: func.func @no_out_of_bounds
135+
// CHECK-LABEL: func.func @no_out_of_bound_write
136136
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
137-
func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
137+
func.func @no_out_of_bound_write(%arg0: i32, %arg1: i32) -> i32 {
138138
// CHECK: %[[C0:.*]] = arith.constant 0 : index
139139
%c0 = arith.constant 0 : index
140140
// CHECK: %[[C100:.*]] = arith.constant 100 : index
@@ -152,3 +152,24 @@ func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
152152
// CHECK: return %[[RES]] : i32
153153
return %res : i32
154154
}
155+
156+
// -----
157+
158+
// CHECK-LABEL: func.func @no_out_of_bound_load
159+
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
160+
func.func @no_out_of_bound_load(%arg0: i32, %arg1: i32) -> i32 {
161+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
162+
%c0 = arith.constant 0 : index
163+
// CHECK: %[[C100:.*]] = arith.constant 100 : index
164+
%c100 = arith.constant 100 : index
165+
// CHECK-NOT: = memref.alloca()
166+
// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
167+
// CHECK-NOT: = memref.alloca()
168+
%alloca = memref.alloca() : memref<2xi32>
169+
// CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
170+
memref.store %arg0, %alloca[%c0] : memref<2xi32>
171+
// CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C100]]]
172+
%res = memref.load %alloca[%c100] : memref<2xi32>
173+
// CHECK: return %[[RES]] : i32
174+
return %res : i32
175+
}

0 commit comments

Comments
 (0)