Skip to content

Commit 7323aa0

Browse files
committed
Add assertion in outer `inferResultType
1 parent 30fc408 commit 7323aa0

File tree

3 files changed

+13
-31
lines changed

3 files changed

+13
-31
lines changed

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,6 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
7373
failed(foldDynamicIndexList(mixedStrides)))
7474
return failure();
7575

76-
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
77-
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
78-
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);
79-
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
80-
dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides);
81-
82-
for (int64_t size : staticSizes) {
83-
if (size < 0 && !ShapedType::isDynamic(size)) {
84-
return op.emitError("expected non-negative size, but got ")
85-
<< size;;
86-
}
87-
}
88-
8976
// Create the new op in canonical form.
9077
auto resultType =
9178
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,6 +2621,19 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
26212621
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
26222622
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
26232623
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2624+
2625+
// Double-check the offsets, sizes, and strides after constant folding.
2626+
// This allows throwing a more informative assertion message than
2627+
// what would be thrown at a later point.
2628+
for (int64_t offset : staticOffsets) {
2629+
if (!ShapedType::isDynamic(offset))
2630+
assert(offset >= 0 && "expected subview offsets to be non-negative");
2631+
}
2632+
for (int64_t size : staticSizes) {
2633+
if (!ShapedType::isDynamic(size))
2634+
assert(size >= 0 && "expected subview sizes to be non-negative");
2635+
}
2636+
26242637
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
26252638
staticSizes, staticStrides);
26262639
}
@@ -2843,8 +2856,6 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
28432856
}
28442857

28452858
LogicalResult SubViewOp::verify() {
2846-
llvm::outs() << "SubViewOp::verify\n";
2847-
28482859
for (int64_t offset : getStaticOffsets()) {
28492860
if (offset < 0 && !ShapedType::isDynamic(offset))
28502861
return emitError("expected subview offsets to be non-negative, but got ")
@@ -3105,20 +3116,6 @@ struct SubViewReturnTypeCanonicalizer {
31053116
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
31063117
ArrayRef<OpFoldResult> mixedSizes,
31073118
ArrayRef<OpFoldResult> mixedStrides) {
3108-
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3109-
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3110-
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);
3111-
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
3112-
dispatchIndexOpFoldResults(mixedStrides, dynamicStrides, staticStrides);
3113-
3114-
for (int64_t size : staticSizes) {
3115-
if (size < 0 && !ShapedType::isDynamic(size)) {
3116-
llvm::dbgs() << "expected subview sizes to be non-negative, but got "
3117-
<< size << "\n";
3118-
return {};
3119-
}
3120-
}
3121-
31223119
// Infer a memref type without taking into account any rank reductions.
31233120
MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
31243121
op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,8 +1163,6 @@ static void operandsAndShape(TensorType resultType,
11631163
}
11641164

11651165
LogicalResult GenerateOp::verify() {
1166-
llvm::outs() << "GenerateOp::verify()\n";
1167-
11681166
// Ensure that the tensor type has as many dynamic dimensions as are
11691167
// specified by the operands.
11701168
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());

0 commit comments

Comments
 (0)