-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Add check for indices of tensor.gather
#106894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch add a check for indices of `tensor.gather` and `tensor.scatter`. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices.
e501ccc
to
d8bcd16
Compare
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis patch add a check for indices of Full diff: https://github.com/llvm/llvm-project/pull/106894.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..603773c377e915 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1288,7 +1288,8 @@ RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
}
static LogicalResult
-verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
+verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
+ ArrayRef<int64_t> indices, int64_t rank,
StringRef gatherOrScatter, StringRef sourceOrDest) {
if (dims.empty())
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
@@ -1297,6 +1298,9 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
if (numGatherDims > rank)
return op->emitOpError(gatherOrScatter)
<< "_dims overflow " << sourceOrDest << " rank";
+ if (indices.empty() || indices.back() != numGatherDims)
+ return op->emitOpError(gatherOrScatter)
+ << "_dims length must match the size of last dimension of indices";
for (int64_t val : dims) {
if (val < 0)
return op->emitOpError(gatherOrScatter)
@@ -1316,7 +1320,8 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
LogicalResult GatherOp::verify() {
int64_t sourceRank = getSourceType().getRank();
ArrayRef<int64_t> gatherDims = getGatherDims();
- if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
+ if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
+ getIndicesType().getShape(), sourceRank,
"gather", "source")))
return failure();
@@ -3530,7 +3535,8 @@ void ScatterOp::getAsmResultNames(
LogicalResult ScatterOp::verify() {
int64_t destRank = getDestType().getRank();
ArrayRef<int64_t> scatterDims = getScatterDims();
- if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
+ if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
+ getIndicesType().getShape(), destRank,
"scatter", "dest")))
return failure();
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index d9db32b8801ac2..84e6c59e403dde 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -455,41 +455,59 @@ func.func @gather_coordinate_rank_overflow(
// -----
+func.func @gather_coordinate_rank_mismatch0(
+ %source: tensor<4x5x6xf32>, %indices: tensor<index>) {
+ // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
+ %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
+ (tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
+}
+
+// -----
+
+func.func @gather_coordinate_rank_mismatch1(
+ %source: tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+ // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
+ %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
+ (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
+}
+
+// -----
+
func.func @gather_coordinate_negative(
- %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
// expected-error@+1 {{gather_dims value must be non-negative}}
%out = tensor.gather %source[%indices] gather_dims([-1]):
- (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
return
}
// -----
func.func @gather_coordinate_overflow(
- %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
// expected-error@+1 {{gather_dims value must be smaller than source rank}}
%out = tensor.gather %source[%indices] gather_dims([42]):
- (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
return
}
// -----
-func.func @gather_coordinate_overflow(
- %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+func.func @gather_coordinate_increase(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
// expected-error@+1 {{gather_dims values must be strictly increasing}}
%out = tensor.gather %source[%indices] gather_dims([1, 0]):
- (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
return
}
// -----
func.func @gather_wrong_result_type(
- %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
// expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
%out = tensor.gather %source[%indices] gather_dims([0, 2]):
- (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
return
}
@@ -517,12 +535,34 @@ func.func @scatter_coordinate_rank_overflow(
// -----
+func.func @scatter_coordinate_rank_mismatch0(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<index>) {
+ // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_coordinate_rank_mismatch1(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+ // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
func.func @scatter_coordinate_negative(
%source : tensor<f32>,
- %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
// expected-error@+1 {{scatter_dims value must be non-negative}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
- (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
return
}
@@ -530,21 +570,21 @@ func.func @scatter_coordinate_negative(
func.func @scatter_coordinate_overflow(
%source : tensor<f32>,
- %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
// expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
- (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
return
}
// -----
-func.func @scatter_coordinate_overflow(
+func.func @scatter_coordinate_increase(
%source : tensor<f32>,
- %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
// expected-error@+1 {{scatter_dims values must be strictly increasing}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
- (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
return
}
@@ -552,10 +592,10 @@ func.func @scatter_coordinate_overflow(
func.func @scatter_missing_unique(
%source : tensor<f32>,
- %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
// expected-error@+1 {{requires 'unique' attribute to be set}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
- (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
return
}
@@ -563,10 +603,10 @@ func.func @scatter_missing_unique(
func.func @scatter_wrong_result_type(
%source : tensor<f32>,
- %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
// expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
- (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
return
}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Ping~ |
ping~ |
Thanks. |
This patch add a check for indices of
tensor.gather
andtensor.scatter
. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices. Fix #94901.