-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Add verifier checks for Gather #137204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]> Change-Id: I16685bceef25f428669c5412d897b6918a424119
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis adds verifier checks for the gather op Full diff: https://github.com/llvm/llvm-project/pull/137204.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..22aca774a403d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
}
LogicalResult tosa::GatherOp::verify() {
- return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
- /* outType = */ getOutput().getType());
+ if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+
+ const ShapeAdaptor valuesShape(getValues().getType());
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ const ShapeAdaptor outputShape(getOutput().getType());
+
+ int64_t N = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+
+ if (valuesShape.hasRank()) {
+ N = valuesShape.getDimSize(0);
+ C = valuesShape.getDimSize(2);
+ }
+ if (indicesShape.hasRank()) {
+ const int64_t indicesN = indicesShape.getDimSize(0);
+ W = indicesShape.getDimSize(1);
+ if (N == ShapedType::kDynamic)
+ N = indicesN;
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << N
+ << ", got " << indicesN;
+ }
+ if (outputShape.hasRank()) {
+ const int64_t outputN = outputShape.getDimSize(0);
+ const int64_t outputW = outputShape.getDimSize(1);
+ const int64_t outputC = outputShape.getDimSize(2);
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ N != outputN)
+ return emitOpError() << "requires output dimension 0 to have size " << N
+ << ", got " << outputN;
+
+ if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
+ W != outputW)
+ return emitOpError() << "requires output dimension 1 to have size " << W
+ << ", got " << outputW;
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ C != outputC)
+ return emitOpError() << "requires output dimension 2 to have size " << C
+ << ", got " << outputC;
+ }
+ return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..2b78773e4ed7f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
}
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_indices_N
+func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_N
+func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
+ return %0 : tensor<12x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_W
+func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_C
+func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
+ // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
+ return %0 : tensor<13x26x8xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. --------- Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. --------- Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. --------- Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. --------- Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.