Skip to content

[mlir][tosa] Add verifier checks for Gather #137204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2025
Merged

Conversation

Tai78641
Copy link
Contributor

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I16685bceef25f428669c5412d897b6918a424119
@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.


Full diff: https://github.com/llvm/llvm-project/pull/137204.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+46-2)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+32)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..22aca774a403d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::GatherOp::verify() {
-  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
-                                /* outType = */ getOutput().getType());
+  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  const ShapeAdaptor valuesShape(getValues().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor outputShape(getOutput().getType());
+
+  int64_t N = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+
+  if (valuesShape.hasRank()) {
+    N = valuesShape.getDimSize(0);
+    C = valuesShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    W = indicesShape.getDimSize(1);
+    if (N == ShapedType::kDynamic)
+      N = indicesN;
+    else if (indicesN != ShapedType::kDynamic && N != indicesN)
+      return emitOpError() << "requires indices dimension 0 to have size " << N
+                           << ", got " << indicesN;
+  }
+  if (outputShape.hasRank()) {
+    const int64_t outputN = outputShape.getDimSize(0);
+    const int64_t outputW = outputShape.getDimSize(1);
+    const int64_t outputC = outputShape.getDimSize(2);
+    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+        N != outputN)
+      return emitOpError() << "requires output dimension 0 to have size " << N
+                           << ", got " << outputN;
+
+    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
+        W != outputW)
+      return emitOpError() << "requires output dimension 1 to have size " << W
+                           << ", got " << outputW;
+    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+        C != outputC)
+      return emitOpError() << "requires output dimension 2 to have size " << C
+                           << ", got " << outputC;
+  }
+  return success();
 }
 
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..2b78773e4ed7f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_indices_N
+func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_N
+func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
+  return %0 : tensor<12x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_W
+func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_C
+func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
+  return %0 : tensor<13x26x8xf32>
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@lhutton1 lhutton1 merged commit b4e2592 into llvm:main Apr 25, 2025
7 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

---------

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

---------

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

---------

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

---------

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants