Skip to content

Commit b6d67af

Browse files
authored
[mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (#70972)
This patch adds an verifier to `tosa.tile` which checks input/output ranks and the length of the `multiples` array. The patch also fixes a crash in the shape inference when an invalid `multiples` array is supplied. Fix #70415
1 parent 9e0a5be commit b6d67af

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
16441644
);
16451645

16461646
let hasFolder = 1;
1647+
let hasVerifier = 1;
16471648
}
16481649

16491650
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
863863
outputShape.resize(multiples.size(), ShapedType::kDynamic);
864864
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
865865
return success();
866-
}
866+
} else if (inputShape.getRank() != multiples.size())
867+
return failure();
867868

868869
// Any non dynamic dimension can be multiplied to a known size.
869870
outputShape.reserve(multiples.size());
@@ -878,6 +879,24 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
878879
return success();
879880
}
880881

882+
LogicalResult tosa::TileOp::verify() {
883+
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
884+
ShapedType outputType = llvm::cast<ShapedType>(getType());
885+
auto multiples = getMultiples();
886+
887+
if (inputType.hasRank()) {
888+
if (inputType.getRank() != multiples.size())
889+
return emitOpError("expect 'multiples' array to have length ")
890+
<< inputType.getRank() << " but got " << multiples.size() << ".";
891+
if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
892+
return emitOpError("expect same input and output tensor rank.");
893+
} else if (outputType.hasRank() && outputType.getRank() != multiples.size())
894+
return emitOpError("expect 'multiples' array to have length ")
895+
<< outputType.getRank() << " but got " << multiples.size() << ".";
896+
897+
return success();
898+
}
899+
881900
bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
882901
if (l.size() != r.size() || l.size() != 1)
883902
return false;

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
603603
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
604604
return
605605
}
606+
607+
// -----
608+
609+
// CHECK-LABEL: @fold_tile_rank_zero
610+
func.func nested @fold_tile_rank_zero() -> tensor<i32> {
611+
// CHECK-NOT: tosa.tile
612+
%0 = tensor.empty() : tensor<i32>
613+
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
614+
return %1 : tensor<i32>
615+
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
329329
%1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
330330
return
331331
}
332+
333+
// -----
334+
335+
func.func @test_tile_invalid_multiples() {
336+
%0 = tensor.empty() : tensor<4x31x31xf32>
337+
// expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
338+
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
339+
return
340+
}

0 commit comments

Comments
 (0)