-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[tosa] Add duplicate indices check for Scatter #143736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values. Signed-off-by: Tai Ly <[email protected]> Change-Id: I19d12f76f36f2be930971b8d6393299081db9446
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesTosa scatter operator disallow duplicate indices (per batch) Full diff: https://github.com/llvm/llvm-project/pull/143736.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
const int rank);
+
+// returns true iff constant indices for scatter op contains unique indices
+// per batch
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+ DenseIntElementsAttr indicesAttr);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
return true;
}
+bool checkErrorIfScatter(Operation *op) {
+ auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
+ if (!scatterOp)
+ return true;
+
+ // for constant indices, check that there are no duplicate values
+ DenseIntElementsAttr indicesAttr;
+ if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
+ return true;
+
+ auto const indicesType =
+ dyn_cast<ShapedType>(scatterOp.getIndices().getType());
+ if (!indicesType || !indicesType.hasRank()) {
+ op->emitOpError("expect ranked indices tensor");
+ return false;
+ }
+
+ if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
+ op->emitOpError("indices values contain duplicates");
+ return false;
+ }
+
+ return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
- !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+ !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+ !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
}
return {};
}
+
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+ ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+ llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+ const unsigned int indicesRank = indicesShape.size();
+ const unsigned int lastDimSize = indicesShape[indicesRank - 1];
+
+ // check each batch of indices from the flat indicesAttr values
+ // for duplicates
+ auto const indicesValues = indicesAttr.getValues<int32_t>();
+ assert(
+ (indicesValues.size() % lastDimSize == 0) &&
+ "Constant indices data length should be a multiple of indicesShape[-1]");
+
+ std::vector<uint64_t> indices(lastDimSize);
+ for (auto beg = indicesValues.begin(); beg < indicesValues.end();
+ beg += lastDimSize) {
+ std::copy(beg, beg + lastDimSize, indices.begin());
+ std::sort(indices.begin(), indices.end());
+ if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
+ // found duplicate values in indices in batch
+ return false;
+ }
+ }
+
+ return true;
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
+
+// -----
+
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+ %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+ // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+ %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+ return %0 : tensor<2x52x3xf32>
+}
|
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesTosa scatter operator disallow duplicate indices (per batch) Full diff: https://github.com/llvm/llvm-project/pull/143736.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..6f3b0916a7a60 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
const int rank);
+
+// returns true iff constant indices for scatter op contains unique indices
+// per batch
+bool hasUniqueConstantScatterIndices(ShapedType indicesType,
+ DenseIntElementsAttr indicesAttr);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..229f42d3178b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
return true;
}
+bool checkErrorIfScatter(Operation *op) {
+ auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
+ if (!scatterOp)
+ return true;
+
+ // for constant indices, check that there are no duplicate values
+ DenseIntElementsAttr indicesAttr;
+ if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
+ return true;
+
+ auto const indicesType =
+ dyn_cast<ShapedType>(scatterOp.getIndices().getType());
+ if (!indicesType || !indicesType.hasRank()) {
+ op->emitOpError("expect ranked indices tensor");
+ return false;
+ }
+
+ if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
+ op->emitOpError("indices values contain duplicates");
+ return false;
+ }
+
+ return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
- !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+ !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+ !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..9844abcc34cb1 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
}
return {};
}
+
+bool mlir::tosa::hasUniqueConstantScatterIndices(
+ ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
+ llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+ const unsigned int indicesRank = indicesShape.size();
+ const unsigned int lastDimSize = indicesShape[indicesRank - 1];
+
+ // check each batch of indices from the flat indicesAttr values
+ // for duplicates
+ auto const indicesValues = indicesAttr.getValues<int32_t>();
+ assert(
+ (indicesValues.size() % lastDimSize == 0) &&
+ "Constant indices data length should be a multiple of indicesShape[-1]");
+
+ std::vector<uint64_t> indices(lastDimSize);
+ for (auto beg = indicesValues.begin(); beg < indicesValues.end();
+ beg += lastDimSize) {
+ std::copy(beg, beg + lastDimSize, indices.begin());
+ std::sort(indices.begin(), indices.end());
+ if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
+ // found duplicate values in indices in batch
+ return false;
+ }
+ }
+
+ return true;
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a4617fc6fba8b..805522799a6d8 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}
+
+// -----
+
+// CHECK-LABEL: test_scatter_duplicate_indices
+func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+ %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
+ // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+ %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+ return %0 : tensor<2x52x3xf32>
+}
|
LGTM. Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Tosa scatter operator disallow duplicate indices (per batch) This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values. Signed-off-by: Tai Ly <[email protected]>
Tosa scatter operator disallow duplicate indices (per batch) This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values. Signed-off-by: Tai Ly <[email protected]>
Tosa scatter operator disallow duplicate indices (per batch)
This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values.