Skip to content

Commit 30fc408

Browse files
committed
Catch negative size earlier
This now does catch the negative size inside the interface so that an `op.emitError` can be thrown. That works, but then continues to return an empty result? Instead, the interface can probably be refactored first because it's very restrictive in its current form.
1 parent 32adf97 commit 30fc408

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/OpImplementation.h"
2121
#include "mlir/IR/PatternMatch.h"
22+
#include <_types/_uint64_t.h>
2223

2324
namespace mlir {
2425

@@ -72,6 +73,19 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
7273
failed(foldDynamicIndexList(mixedStrides)))
7374
return failure();
7475

76+
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
77+
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
78+
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);
79+
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
80+
dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides);
81+
82+
for (int64_t size : staticSizes) {
83+
if (size < 0 && !ShapedType::isDynamic(size)) {
84+
return op.emitError("expected non-negative size, but got ")
85+
<< size;;
86+
}
87+
}
88+
7589
// Create the new op in canonical form.
7690
auto resultType =
7791
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2843,6 +2843,8 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
28432843
}
28442844

28452845
LogicalResult SubViewOp::verify() {
2846+
llvm::outs() << "SubViewOp::verify\n";
2847+
28462848
for (int64_t offset : getStaticOffsets()) {
28472849
if (offset < 0 && !ShapedType::isDynamic(offset))
28482850
return emitError("expected subview offsets to be non-negative, but got ")
@@ -3103,6 +3105,19 @@ struct SubViewReturnTypeCanonicalizer {
31033105
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
31043106
ArrayRef<OpFoldResult> mixedSizes,
31053107
ArrayRef<OpFoldResult> mixedStrides) {
3108+
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3109+
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3110+
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);
3111+
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
3112+
dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides);
3113+
3114+
for (int64_t size : staticSizes) {
3115+
if (size < 0 && !ShapedType::isDynamic(size)) {
3116+
llvm::dbgs() << "expected subview sizes to be non-negative, but got "
3117+
<< size << "\n";
3118+
return {};
3119+
}
3120+
}
31063121

31073122
// Infer a memref type without taking into account any rank reductions.
31083123
MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,8 @@ static void operandsAndShape(TensorType resultType,
11631163
}
11641164

11651165
LogicalResult GenerateOp::verify() {
1166+
llvm::outs() << "GenerateOp::verify()\n";
1167+
11661168
// Ensure that the tensor type has as many dynamic dimensions as are
11671169
// specified by the operands.
11681170
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
@@ -1173,6 +1175,7 @@ LogicalResult GenerateOp::verify() {
11731175
SmallVector<Value> newOperands;
11741176
SmallVector<int64_t> newShape;
11751177
operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
1178+
11761179
for (int64_t newdim : newShape) {
11771180
if (newdim < 0 && !ShapedType::isDynamic(newdim))
11781181
return emitError("tensor dimensions must be non-negative");
@@ -1242,7 +1245,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
12421245

12431246
for (int64_t newdim : newShape) {
12441247
// This check also occurs in the verifier, but we need it here too
1245-
// since intermediate passes may have some replaced dynamic dimensions
1248+
// since intermediate passes may have replaced some dynamic dimensions
12461249
// by constants.
12471250
if (newdim < 0 && !ShapedType::isDynamic(newdim))
12481251
return failure();

0 commit comments

Comments
 (0)