Skip to content

Commit 1072196

Browse files
authored
[tosa] Add duplicate indices check for Scatter (#143736)
Tosa scatter operator disallow duplicate indices (per batch) This patch adds, to the validation pass, checking for duplicate values in scatter operator's constant indices values. Signed-off-by: Tai Ly <[email protected]>
1 parent 9e23e85 commit 1072196

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
243243
// returns a small vector of int64_t values that attr contains
244244
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
245245
const int rank);
246+
247+
// returns true iff constant indices for scatter op contains unique indices
248+
// per batch
249+
bool hasUniqueConstantScatterIndices(ShapedType indicesType,
250+
DenseIntElementsAttr indicesAttr);
246251
} // namespace tosa
247252
} // namespace mlir
248253

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
12441244
return true;
12451245
}
12461246

1247+
bool checkErrorIfScatter(Operation *op) {
1248+
auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
1249+
if (!scatterOp)
1250+
return true;
1251+
1252+
// for constant indices, check that there are no duplicate values
1253+
DenseIntElementsAttr indicesAttr;
1254+
if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
1255+
return true;
1256+
1257+
auto const indicesType =
1258+
dyn_cast<ShapedType>(scatterOp.getIndices().getType());
1259+
if (!indicesType || !indicesType.hasRank()) {
1260+
op->emitOpError("expect ranked indices tensor");
1261+
return false;
1262+
}
1263+
1264+
if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
1265+
op->emitOpError("indices values contain duplicates");
1266+
return false;
1267+
}
1268+
1269+
return true;
1270+
}
1271+
12471272
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
12481273
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
12491274
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
1250-
!checkErrorIfPad(op) || !checkErrorIfCondIf(op))
1275+
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
1276+
!checkErrorIfScatter(op))
12511277
return failure();
12521278
return success();
12531279
}

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
213213
}
214214
return {};
215215
}
216+
217+
bool mlir::tosa::hasUniqueConstantScatterIndices(
218+
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
219+
llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
220+
const unsigned int indicesRank = indicesShape.size();
221+
const unsigned int lastDimSize = indicesShape[indicesRank - 1];
222+
223+
// check each batch of indices from the flat indicesAttr values
224+
// for duplicates
225+
auto const indicesValues = indicesAttr.getValues<int32_t>();
226+
assert(
227+
(indicesValues.size() % lastDimSize == 0) &&
228+
"Constant indices data length should be a multiple of indicesShape[-1]");
229+
230+
std::vector<uint64_t> indices(lastDimSize);
231+
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
232+
beg += lastDimSize) {
233+
std::copy(beg, beg + lastDimSize, indices.begin());
234+
std::sort(indices.begin(), indices.end());
235+
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
236+
// found duplicate values in indices in batch
237+
return false;
238+
}
239+
}
240+
241+
return true;
242+
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
20152015
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
20162016
return %r : tensor<1x1xui8>
20172017
}
2018+
2019+
// -----
2020+
2021+
// CHECK-LABEL: test_scatter_duplicate_indices
2022+
func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
2023+
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
2024+
// expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
2025+
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
2026+
return %0 : tensor<2x52x3xf32>
2027+
}

0 commit comments

Comments
 (0)