
Commit 1949fe9

[mlir] Verify non-negative offset and size (#72059)
In #71153, the `memref.subview` canonicalizer crashes because a negative `size` is passed as an operand. At `SubViewOp::verify` time this negative `size` is not yet detectable: it is still a dynamic value and only becomes a known constant after folding, which happens during the canonicalization passes. As discussed in <https://discourse.llvm.org/t/rfc-more-opfoldresult-and-mixed-indices-in-ops-that-deal-with-shaped-values/72510>, the verifier should not be extended, since it should "only verify local aspects of an operation". This patch fixes #71153 by not folding in the aforementioned situation. It also adds a basic offset and size check to the `OffsetSizeAndStrideOpInterface` verifier. Note: only `offset` and `size` are checked, because `stride` is allowed to be negative (54d81e4).
1 parent de8f906 · commit 1949fe9
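The checks this patch adds all share one subtlety: in MLIR, a dynamic offset or size is encoded in the static arrays as a sentinel integer (`ShapedType::kDynamic`), and that sentinel is itself negative (assumed here to be `std::numeric_limits<int64_t>::min()`), so a bare `value < 0` test would also reject every dynamic entry. Below is a minimal, self-contained sketch of the check pattern, not code from the patch; `kDynamicSentinel`, `isDynamic`, and `allNonNegative` are stand-ins for the real `ShapedType` API.

```cpp
// Minimal sketch (not code from the patch): mimic the non-negativity check
// added to SubViewOp::inferResultType and verifyOffsetSizeAndStrideOp.
// kDynamicSentinel stands in for ShapedType::kDynamic, assumed here to be
// std::numeric_limits<int64_t>::min(); because the sentinel is negative,
// dynamic entries must be excluded explicitly.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

constexpr std::int64_t kDynamicSentinel =
    std::numeric_limits<std::int64_t>::min();

bool isDynamic(std::int64_t v) { return v == kDynamicSentinel; }

// Returns false if any static (i.e. non-dynamic) entry is negative.
bool allNonNegative(const std::vector<std::int64_t> &staticValues) {
  for (std::int64_t v : staticValues)
    if (v < 0 && !isDynamic(v))
      return false;
  return true;
}

int main() {
  std::cout << allNonNegative({2, 256}) << '\n';                // 1: valid
  std::cout << allNonNegative({kDynamicSentinel, 256}) << '\n'; // 1: dynamic ok
  std::cout << allNonNegative({-13, 256}) << '\n';              // 0: rejected
}
```

In the patch itself this pattern appears twice with different failure modes: `SubViewOp::inferResultType` quietly refuses to fold (returns a null type), while the `OffsetSizeAndStrideOpInterface` verifier emits a diagnostic.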

File tree: 6 files changed, +56 −3 lines

mlir/include/mlir/Interfaces/ViewLikeInterface.td

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
        for each dynamic offset (resp. size, stride).
     5. `offsets`, `sizes` and `strides` operands are specified in this order
        at operand index starting at `getOffsetSizeAndStrideStartOperandIndex`.
+    6. `offsets` and `sizes` operands are non-negative.
 
     This interface is useful to factor out common behavior and provide support
     for carrying or injecting static behavior through the use of the static

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 16 additions & 2 deletions
@@ -2621,6 +2621,17 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  // If one of the offsets or sizes is invalid, fail the canonicalization.
+  // These checks also occur in the verifier, but they are needed here
+  // because some dynamic dimensions may have been constant folded.
+  for (int64_t offset : staticOffsets)
+    if (offset < 0 && !ShapedType::isDynamic(offset))
+      return {};
+  for (int64_t size : staticSizes)
+    if (size < 0 && !ShapedType::isDynamic(size))
+      return {};
+
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }

@@ -3094,8 +3105,11 @@ struct SubViewReturnTypeCanonicalizer {
                          ArrayRef<OpFoldResult> mixedSizes,
                          ArrayRef<OpFoldResult> mixedStrides) {
     // Infer a memref type without taking into account any rank reductions.
-    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
-        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
+                                            mixedSizes, mixedStrides);
+    if (!resTy)
+      return {};
+    MemRefType nonReducedType = cast<MemRefType>(resTy);
 
     // Directly return the non-rank reduced type if there are no dropped dims.
     llvm::SmallBitVector droppedDims = op.getDroppedDims();
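Because this overload of `SubViewOp::inferResultType` can now return a null `Type`, callers must check the result before casting; the second hunk updates the one caller changed in this commit. A hedged sketch of that caller-side pattern, with illustrative variable names rather than code from the patch:

```cpp
// Illustrative caller-side pattern: a null inferred type means an offset or
// size folded to an invalid (negative) constant, so abandon this fold or
// canonicalization attempt instead of casting the null type.
Type inferred = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
                                           mixedSizes, mixedStrides);
if (!inferred)
  return {};
auto nonReducedType = cast<MemRefType>(inferred);
```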

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -1261,7 +1261,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
 
     for (int64_t newdim : newShape) {
       // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have some replaced dynamic dimensions
+      // since intermediate passes may have replaced some dynamic dimensions
       // by constants.
       if (newdim < 0 && !ShapedType::isDynamic(newdim))
         return failure();

mlir/lib/Interfaces/ViewLikeInterface.cpp

Lines changed: 11 additions & 0 deletions
@@ -66,6 +66,17 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
   if (failed(verifyListOfOperandsOrIntegers(
           op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
     return failure();
+
+  for (int64_t offset : op.getStaticOffsets()) {
+    if (offset < 0 && !ShapedType::isDynamic(offset))
+      return op->emitError("expected offsets to be non-negative, but got ")
+             << offset;
+  }
+  for (int64_t size : op.getStaticSizes()) {
+    if (size < 0 && !ShapedType::isDynamic(size))
+      return op->emitError("expected sizes to be non-negative, but got ")
+             << size;
+  }
   return success();
 }

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
@@ -180,6 +180,17 @@ func.func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
 
 // -----
 
+// CHECK-LABEL: func @no_fold_subview_negative_size
+// CHECK: %[[SUBVIEW:.+]] = memref.subview
+// CHECK: return %[[SUBVIEW]]
+func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?x256xf32, strided<[1024, 1], offset: 2304>> {
+  %cst = arith.constant -13 : index
+  %0 = memref.subview %input[2, 256] [%cst, 256] [1, 1] : memref<4x1024xf32> to memref<?x256xf32, strided<[1024, 1], offset: 2304>>
+  return %0 : memref<?x256xf32, strided<[1024, 1], offset: 2304>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 // CHECK: %[[cst:.+]] = memref.cast %arg
 // CHECK: memref.store %[[cst]]

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 16 additions & 0 deletions
@@ -611,6 +611,22 @@ func.func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 // -----
 
+func.func @invalid_subview(%input: memref<4x1024xf32>) -> memref<2x256xf32, strided<[1024, 1], offset: 2304>> {
+  // expected-error@+1 {{expected offsets to be non-negative, but got -1}}
+  %0 = memref.subview %input[-1, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
+  return %0 : memref<2x256xf32, strided<[1024, 1], offset: 2304>>
+}
+
+// -----
+
+func.func @invalid_subview(%input: memref<4x1024xf32>) -> memref<2x256xf32, strided<[1024, 1], offset: 2304>> {
+  // expected-error@+1 {{expected sizes to be non-negative, but got -1}}
+  %0 = memref.subview %input[2, 256] [-1, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
+  return %0 : memref<2x256xf32, strided<[1024, 1], offset: 2304>>
+}
+
+// -----
+
 func.func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = memref.alloc() : memref<8x16x4xf32>
   // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank (2 vs 3) so the rank of the result type is well-formed}}
