Skip to content

Commit 68f0bc6

Browse files
rikhuijzerLewuathe
andauthored
[mlir] Fix a zero stride canonicalizer crash (#74200)
This PR fixes #73383 and is another shot at the refactoring proposed in #72885. --------- Co-authored-by: Kai Sasaki <[email protected]>
1 parent df7545e commit 68f0bc6

File tree

5 files changed

+78
-25
lines changed

5 files changed

+78
-25
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,36 @@ SmallVector<int64_t>
139139
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
140140
llvm::function_ref<bool(Attribute, Attribute)> compare);
141141

142+
/// Helper function to check whether the passed in `sizes` or `offsets` are
143+
/// valid. This can be used to re-check whether dimensions are still valid
144+
/// after constant folding the dynamic dimensions.
145+
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
146+
147+
/// Helper function to check whether the passed in `strides` are valid. This
148+
/// can be used to re-check whether dimensions are still valid after constant
149+
/// folding the dynamic dimensions.
150+
bool hasValidStrides(SmallVector<int64_t> strides);
151+
142152
/// Returns "success" when any of the elements in `ofrs` is a constant value. In
143153
/// that case the value is replaced by an attribute. Returns "failure" when no
144-
/// folding happened. If `onlyNonNegative` is set, only non-negative constant
145-
/// values are folded.
154+
/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
155+
/// non-negative and non-zero constant values are folded respectively.
146156
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
147-
bool onlyNonNegative = false);
157+
bool onlyNonNegative = false,
158+
bool onlyNonZero = false);
159+
160+
/// Returns "success" when any of the elements in `offsetsOrSizes` is a
161+
/// constant value. In that case the value is replaced by an attribute. Returns
162+
/// "failure" when no folding happened. Invalid values are not folded to avoid
163+
/// canonicalization crashes.
164+
LogicalResult
165+
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
166+
167+
/// Returns "success" when any of the elements in `strides` is a constant
168+
/// value. In that case the value is replaced by an attribute. Returns
169+
/// "failure" when no folding happened. Invalid values are not folded to avoid
170+
/// canonicalization crashes.
171+
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
148172

149173
/// Return the number of iterations for a loop with a lower bound `lb`, upper
150174
/// bound `ub` and step `step`.

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2582,17 +2582,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
25822582
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
25832583
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
25842584
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2585-
2586-
// If one of the offsets or sizes is invalid, fail the canonicalization.
2587-
// These checks also occur in the verifier, but they are needed here
2588-
// because some dynamic dimensions may have been constant folded.
2589-
for (int64_t offset : staticOffsets)
2590-
if (offset < 0 && !ShapedType::isDynamic(offset))
2591-
return {};
2592-
for (int64_t size : staticSizes)
2593-
if (size < 0 && !ShapedType::isDynamic(size))
2594-
return {};
2595-
2585+
if (!hasValidSizesOffsets(staticOffsets))
2586+
return {};
2587+
if (!hasValidSizesOffsets(staticSizes))
2588+
return {};
2589+
if (!hasValidStrides(staticStrides))
2590+
return {};
25962591
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
25972592
staticSizes, staticStrides);
25982593
}

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,13 +1447,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
14471447
SmallVector<int64_t> newShape;
14481448
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
14491449

1450-
for (int64_t newdim : newShape) {
1451-
// This check also occurs in the verifier, but we need it here too
1452-
// since intermediate passes may have replaced some dynamic dimensions
1453-
// by constants.
1454-
if (newdim < 0 && !ShapedType::isDynamic(newdim))
1455-
return failure();
1456-
}
1450+
if (!hasValidSizesOffsets(newShape))
1451+
return failure();
14571452

14581453
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
14591454
return failure();
@@ -2549,9 +2544,9 @@ class InsertSliceOpConstantArgumentFolder final
25492544
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
25502545

25512546
// No constant operands were folded, just return;
2552-
if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
2553-
failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
2554-
failed(foldDynamicIndexList(mixedStrides)))
2547+
if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
2548+
failed(foldDynamicOffsetSizeList(mixedSizes)) &&
2549+
failed(foldDynamicStrideList(mixedStrides)))
25552550
return failure();
25562551

25572552
// Create the new op in canonical form.
@@ -2692,6 +2687,8 @@ struct InsertSliceOpSourceCastInserter final
26922687
newSrcShape[i] = *constInt;
26932688
}
26942689
}
2690+
if (!hasValidSizesOffsets(newSrcShape))
2691+
return failure();
26952692

26962693
RankedTensorType newSrcType =
26972694
RankedTensorType::get(newSrcShape, srcType.getElementType());

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
256256
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
257257
}
258258

259+
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
260+
return llvm::none_of(sizesOrOffsets, [](int64_t value) {
261+
return !ShapedType::isDynamic(value) && value < 0;
262+
});
263+
}
264+
265+
bool hasValidStrides(SmallVector<int64_t> strides) {
266+
return llvm::none_of(strides, [](int64_t value) {
267+
return !ShapedType::isDynamic(value) && value == 0;
268+
});
269+
}
270+
259271
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
260-
bool onlyNonNegative) {
272+
bool onlyNonNegative, bool onlyNonZero) {
261273
bool valuesChanged = false;
262274
for (OpFoldResult &ofr : ofrs) {
263275
if (ofr.is<Attribute>())
@@ -267,11 +279,24 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
267279
// Note: All ofrs have index type.
268280
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
269281
continue;
282+
if (onlyNonZero && *getConstantIntValue(attr) == 0)
283+
continue;
270284
ofr = attr;
271285
valuesChanged = true;
272286
}
273287
}
274288
return success(valuesChanged);
275289
}
276290

291+
LogicalResult
292+
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
293+
return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
294+
/*onlyNonZero=*/false);
295+
}
296+
297+
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
298+
return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
299+
/*onlyNonZero=*/true);
300+
}
301+
277302
} // namespace mlir

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
191191

192192
// -----
193193

194+
// CHECK-LABEL: func @no_fold_subview_zero_stride
195+
// CHECK: %[[SUBVIEW:.+]] = memref.subview
196+
// CHECK: return %[[SUBVIEW]]
197+
func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
198+
%c0 = arith.constant 0 : index
199+
%c1 = arith.constant 1 : index
200+
%1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
201+
return %1 : memref<1xf32, strided<[?], offset: 1>>
202+
}
203+
204+
// -----
205+
194206
// CHECK-LABEL: func @no_fold_of_store
195207
// CHECK: %[[cst:.+]] = memref.cast %arg
196208
// CHECK: memref.store %[[cst]]

0 commit comments

Comments
 (0)