Skip to content

[tosa] Add duplicate indices check for Scatter #143736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025

Conversation

Tai78641
Copy link
Contributor

Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values.

This patch adds, to the validation pass, checking for duplicate values
in scatter operator's constant indices values.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I19d12f76f36f2be930971b8d6393299081db9446
@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values.


Full diff: https://github.com/llvm/llvm-project/pull/143736.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+27-1)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+27)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
 // returns a small vector of int64_t values that attr contains
 SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                         const int rank);
+
+// returns true iff constant indices for scatter op contains unique indices
+// per batch
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+                                     DenseIntElementsAttr indicesAttr);
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
   return true;
 }
 
+bool checkErrorIfScatter(Operation *op) {
+  auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
+  if (!scatterOp)
+    return true;
+
+  // for constant indices, check that there are no duplicate values
+  DenseIntElementsAttr indicesAttr;
+  if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
+    return true;
+
+  auto const indicesType =
+      dyn_cast<ShapedType>(scatterOp.getIndices().getType());
+  if (!indicesType || !indicesType.hasRank()) {
+    op->emitOpError("expect ranked indices tensor");
+    return false;
+  }
+
+  if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
+    op->emitOpError("indices values contain duplicates");
+    return false;
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+      !checkErrorIfScatter(op))
     return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
   }
   return {};
 }
+
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+    ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  const unsigned int indicesRank = indicesShape.size();
+  const unsigned int lastDimSize = indicesShape[indicesRank - 1];
+
+  // check each batch of indices from the flat indicesAttr values
+  // for duplicates
+  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  assert(
+      (indicesValues.size() % lastDimSize == 0) &&
+      "Constant indices data length should be a multiple of indicesShape[-1]");
+
+  std::vector<uint64_t> indices(lastDimSize);
+  for (auto beg = indicesValues.begin(); beg < indicesValues.end();
+       beg += lastDimSize) {
+    std::copy(beg, beg + lastDimSize, indices.begin());
+    std::sort(indices.begin(), indices.end());
+    if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
+      // found duplicate values in indices in batch
+      return false;
+    }
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
   %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
   return %r : tensor<1x1xui8>
 }
+
+// -----
+
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 11, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values.


Full diff: https://github.com/llvm/llvm-project/pull/143736.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+27-1)
  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+27)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
 // returns a small vector of int64_t values that attr contains
 SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                         const int rank);
+
+// returns true iff constant indices for scatter op contains unique indices
+// per batch
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+                                     DenseIntElementsAttr indicesAttr);
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
   return true;
 }
 
+bool checkErrorIfScatter(Operation *op) {
+  auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
+  if (!scatterOp)
+    return true;
+
+  // for constant indices, check that there are no duplicate values
+  DenseIntElementsAttr indicesAttr;
+  if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
+    return true;
+
+  auto const indicesType =
+      dyn_cast<ShapedType>(scatterOp.getIndices().getType());
+  if (!indicesType || !indicesType.hasRank()) {
+    op->emitOpError("expect ranked indices tensor");
+    return false;
+  }
+
+  if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
+    op->emitOpError("indices values contain duplicates");
+    return false;
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+      !checkErrorIfScatter(op))
     return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
   }
   return {};
 }
+
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+    ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  const unsigned int indicesRank = indicesShape.size();
+  const unsigned int lastDimSize = indicesShape[indicesRank - 1];
+
+  // check each batch of indices from the flat indicesAttr values
+  // for duplicates
+  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  assert(
+      (indicesValues.size() % lastDimSize == 0) &&
+      "Constant indices data length should be a multiple of indicesShape[-1]");
+
+  std::vector<uint64_t> indices(lastDimSize);
+  for (auto beg = indicesValues.begin(); beg < indicesValues.end();
+       beg += lastDimSize) {
+    std::copy(beg, beg + lastDimSize, indices.begin());
+    std::sort(indices.begin(), indices.end());
+    if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
+      // found duplicate values in indices in batch
+      return false;
+    }
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
   %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
   return %r : tensor<1x1xui8>
 }
+
+// -----
+
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}

@tatwaichong
Copy link
Contributor

LGTM. Thanks.

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@lhutton1 lhutton1 merged commit 1072196 into llvm:main Jun 13, 2025
10 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values
in scatter operator's constant indices values.

Signed-off-by: Tai Ly <[email protected]>
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values
in scatter operator's constant indices values.

Signed-off-by: Tai Ly <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants