Skip to content

Commit 68386a7

Browse files
[mlir][tensor] Fix crash when canonicalizing invalid IR (#72888)
This commit fixes a crash of the canonicalizer when there are slice ops with offset/size SSA values that have a negative constant value. Such ops are invalid if they are reachable and their offsets/sizes should not be folded to static integer values. (But such ops may appear in non-reachable block.) This commit fixes #71150.
1 parent bebf3a9 commit 68386a7

File tree

5 files changed

+34
-8
lines changed

5 files changed

+34
-8
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
141141

142142
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
143143
/// that case the value is replaced by an attribute. Returns "failure" when no
144-
/// folding happened.
145-
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
144+
/// folding happened. If `onlyNonNegative` is set, only non-negative constant
145+
/// values are folded.
146+
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
147+
bool onlyNonNegative = false);
146148

147149
/// Return the number of iterations for a loop with a lower bound `lb`, upper
148150
/// bound `ub` and step `step`.

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
6767
SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
6868

6969
// No constant operands were folded, just return;
70-
if (failed(foldDynamicIndexList(mixedOffsets)) &&
71-
failed(foldDynamicIndexList(mixedSizes)) &&
70+
if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
71+
failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
7272
failed(foldDynamicIndexList(mixedStrides)))
7373
return failure();
7474

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2361,8 +2361,8 @@ class InsertSliceOpConstantArgumentFolder final
23612361
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
23622362

23632363
// No constant operands were folded, just return;
2364-
if (failed(foldDynamicIndexList(mixedOffsets)) &&
2365-
failed(foldDynamicIndexList(mixedSizes)) &&
2364+
if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
2365+
failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
23662366
failed(foldDynamicIndexList(mixedStrides)))
23672367
return failure();
23682368

@@ -2497,8 +2497,12 @@ struct InsertSliceOpSourceCastInserter final
24972497
srcType.getShape().end());
24982498
for (int64_t i = 0; i < srcType.getRank(); ++i) {
24992499
if (std::optional<int64_t> constInt =
2500-
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
2500+
getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2501+
// Bail on invalid IR.
2502+
if (*constInt < 0)
2503+
return failure();
25012504
newSrcShape[i] = *constInt;
2505+
}
25022506
}
25032507

25042508
RankedTensorType newSrcType =

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,17 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
256256
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
257257
}
258258

259-
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
259+
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
260+
bool onlyNonNegative) {
260261
bool valuesChanged = false;
261262
for (OpFoldResult &ofr : ofrs) {
262263
if (ofr.is<Attribute>())
263264
continue;
264265
Attribute attr;
265266
if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
267+
// Note: All ofrs have index type.
268+
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
269+
continue;
266270
ofr = attr;
267271
valuesChanged = true;
268272
}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,3 +1925,19 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
19251925
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
19261926
// CHECK-SAME: into %[[INIT]]
19271927
// CHECK: return %[[UNPACK]]
1928+
1929+
// -----
1930+
1931+
// The IR in this test case in invalid. This test tests that the canonicalizer
1932+
// does not crash.
1933+
1934+
// CHECK-LABEL: func @invalid_slice_ops(
1935+
// CHECK: %[[c:.*]] = arith.constant -5 : index
1936+
// CHECK: tensor.extract_slice {{.*}}%[[c]]
1937+
// CHECK: tensor.insert_slice {{.*}}%[[c]]
1938+
func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
1939+
%c = arith.constant -5 : index
1940+
%0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
1941+
%1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
1942+
return %1 : tensor<?xf32>
1943+
}

0 commit comments

Comments
 (0)