Skip to content

[mlir][tensor] Add check for indices of tensor.gather #106894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 6, 2024

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Sep 1, 2024

This patch add a check for indices of tensor.gather and tensor.scatter. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices. Fix #94901.

This patch add a check for indices of `tensor.gather` and
`tensor.scatter`. For that the length of gather_dims/scatter_dims
should match the size of last dimension of the indices.
@llvmbot
Copy link
Member

llvmbot commented Sep 1, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This patch add a check for indices of tensor.gather and tensor.scatter. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices. Fix #94901.


Full diff: https://github.com/llvm/llvm-project/pull/106894.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+9-3)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+60-20)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..603773c377e915 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1288,7 +1288,8 @@ RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
 }
 
 static LogicalResult
-verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
+verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
+                          ArrayRef<int64_t> indices, int64_t rank,
                           StringRef gatherOrScatter, StringRef sourceOrDest) {
   if (dims.empty())
     return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
@@ -1297,6 +1298,9 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
   if (numGatherDims > rank)
     return op->emitOpError(gatherOrScatter)
            << "_dims overflow " << sourceOrDest << " rank";
+  if (indices.empty() || indices.back() != numGatherDims)
+    return op->emitOpError(gatherOrScatter)
+           << "_dims length must match the size of last dimension of indices";
   for (int64_t val : dims) {
     if (val < 0)
       return op->emitOpError(gatherOrScatter)
@@ -1316,7 +1320,8 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
 LogicalResult GatherOp::verify() {
   int64_t sourceRank = getSourceType().getRank();
   ArrayRef<int64_t> gatherDims = getGatherDims();
-  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
+  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
+                                       getIndicesType().getShape(), sourceRank,
                                        "gather", "source")))
     return failure();
 
@@ -3530,7 +3535,8 @@ void ScatterOp::getAsmResultNames(
 LogicalResult ScatterOp::verify() {
   int64_t destRank = getDestType().getRank();
   ArrayRef<int64_t> scatterDims = getScatterDims();
-  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
+  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
+                                       getIndicesType().getShape(), destRank,
                                        "scatter", "dest")))
     return failure();
 
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index d9db32b8801ac2..84e6c59e403dde 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -455,41 +455,59 @@ func.func @gather_coordinate_rank_overflow(
 
 // -----
 
+func.func @gather_coordinate_rank_mismatch0(
+    %source: tensor<4x5x6xf32>, %indices: tensor<index>) {
+  // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
+  %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
+    (tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
+}
+
+// -----
+
+func.func @gather_coordinate_rank_mismatch1(
+    %source: tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
+  %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
+    (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
+}
+
+// -----
+
 func.func @gather_coordinate_negative(
-    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
   // expected-error@+1 {{gather_dims value must be non-negative}}
   %out = tensor.gather %source[%indices] gather_dims([-1]):
-    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
   return
 }
 
 // -----
 
 func.func @gather_coordinate_overflow(
-    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
   // expected-error@+1 {{gather_dims value must be smaller than source rank}}
   %out = tensor.gather %source[%indices] gather_dims([42]):
-    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
   return
 }
 
 // -----
 
-func.func @gather_coordinate_overflow(
-    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+func.func @gather_coordinate_increase(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
   // expected-error@+1 {{gather_dims values must be strictly increasing}}
   %out = tensor.gather %source[%indices] gather_dims([1, 0]):
-    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
   return
 }
 
 // -----
 
 func.func @gather_wrong_result_type(
-    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
   // expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
   %out = tensor.gather %source[%indices] gather_dims([0, 2]):
-    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+    (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
   return
 }
 
@@ -517,12 +535,34 @@ func.func @scatter_coordinate_rank_overflow(
 
 // -----
 
+func.func @scatter_coordinate_rank_mismatch0(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<index>) {
+  // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_coordinate_rank_mismatch1(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
 func.func @scatter_coordinate_negative(
     %source : tensor<f32>,
-    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
   // expected-error@+1 {{scatter_dims value must be non-negative}}
   %out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
-    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
   return
 }
 
@@ -530,21 +570,21 @@ func.func @scatter_coordinate_negative(
 
 func.func @scatter_coordinate_overflow(
     %source : tensor<f32>,
-    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
   // expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
   %out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
-    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
   return
 }
 
 // -----
 
-func.func @scatter_coordinate_overflow(
+func.func @scatter_coordinate_increase(
     %source : tensor<f32>,
-    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
   // expected-error@+1 {{scatter_dims values must be strictly increasing}}
   %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
-    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
   return
 }
 
@@ -552,10 +592,10 @@ func.func @scatter_coordinate_overflow(
 
 func.func @scatter_missing_unique(
     %source : tensor<f32>,
-    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
   // expected-error@+1 {{requires 'unique' attribute to be set}}
   %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
-    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
   return
 }
 
@@ -563,10 +603,10 @@ func.func @scatter_missing_unique(
 
 func.func @scatter_wrong_result_type(
     %source : tensor<f32>,
-    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
   // expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
   %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
-    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
   return
 }
 

Copy link

github-actions bot commented Sep 1, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@CoTinker
Copy link
Contributor Author

CoTinker commented Sep 4, 2024

Ping~

@CoTinker
Copy link
Contributor Author

CoTinker commented Sep 5, 2024

ping~

@CoTinker
Copy link
Contributor Author

CoTinker commented Sep 6, 2024

Thanks.

@CoTinker CoTinker merged commit ede40da into llvm:main Sep 6, 2024
8 checks passed
@CoTinker CoTinker deleted the check_indices branch September 6, 2024 02:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] tensor::GatherOp gather_dims not verified correctly?
3 participants