Skip to content

Commit 1e34706

Browse files
authored
[mlir][tosa] Add verifier for tosa.table (#103708)
This patch adds a verifier to `tosa.table` which fixes a crash. Fix #103086.
1 parent 372842b commit 1e34706

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
897897
let assemblyFormat = [{
898898
$input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)
899899
}];
900+
901+
let hasVerifier = 1;
900902
}
901903

902904
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,29 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
864864
return success();
865865
}
866866

867+
LogicalResult tosa::TableOp::verify() {
868+
TensorType inputType = getInput().getType();
869+
TensorType outputType = getOutput().getType();
870+
871+
if (inputType.hasRank() && outputType.hasRank() &&
872+
inputType.getRank() != outputType.getRank())
873+
return emitOpError()
874+
<< "expected input tensor rank to equal result tensor rank";
875+
876+
auto inputDims = inputType.getShape();
877+
auto outputDims = outputType.getShape();
878+
for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
879+
int64_t dim = it.index();
880+
auto [inputDim, outputDim] = it.value();
881+
if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
882+
return emitOpError() << "dim(result, " << dim << ") = " << outputDim
883+
<< " doesn't match dim(input, " << dim
884+
<< ") = " << inputDim;
885+
}
886+
}
887+
return success();
888+
}
889+
867890
LogicalResult tosa::TileOp::inferReturnTypeComponents(
868891
MLIRContext *context, ::std::optional<Location> location,
869892
TileOp::Adaptor adaptor,

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,30 @@ func.func @test_large_constant_permutation() {
448448
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
449449
return
450450
}
451+
452+
// -----
453+
454+
// CHECK-LABEL: test_table_rank0_table
455+
func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
456+
// expected-error@+1 {{'tosa.table' op operand #1 must be 1-d tensor, but got 'tensor<i16>'}}
457+
%0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<i16>) -> tensor<64xi16>
458+
return
459+
}
460+
461+
// -----
462+
463+
// CHECK-LABEL: test_table_io_rank_mismatch
464+
func.func @test_table_io_rank_mismatch(%arg0: tensor<64xi16>, %arg1: tensor<6xi16>) {
465+
// expected-error@+1 {{'tosa.table' op expected input tensor rank to equal result tensor rank}}
466+
%0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<6xi16>) -> tensor<64x?xi16>
467+
return
468+
}
469+
470+
// -----
471+
472+
// CHECK-LABEL: test_table_io_shape_mismatch
473+
func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6xi16>) {
474+
// expected-error@+1 {{'tosa.table' op dim(result, 1) = 15 doesn't match dim(input, 1) = 16}}
475+
%0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
476+
return
477+
}

0 commit comments

Comments
 (0)