Skip to content

Commit 783b4d9

Browse files
authored
[mlir][tosa] Check for 0-ranked-tensors during fold (llvm#68512)
Fixes llvm#67761 Trying `getDimSize()` before checking for 0-ranked-tensors throws assert errors. This PR ensures that it is checked for. Or should we throw an error if we have a 0-ranked-tensor in a tosa operation?
1 parent a4803d8 commit 783b4d9

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
771771
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
772772
if (!inputTy.hasRank()) \
773773
return {}; \
774-
if (inputTy.getDimSize(getAxis()) == 1) \
774+
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
775775
return getInput(); \
776776
return {}; \
777777
}
@@ -874,7 +874,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
874874
return operandAttr;
875875

876876
// If the dim-length is 1, tosa.reverse is a no-op.
877-
if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
877+
if (operandTy.hasRank() &&
878+
(operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
878879
return operand;
879880

880881
return {};

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,7 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
11091109
static LogicalResult ReduceInferReturnTypes(
11101110
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
11111111
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1112-
if (!operandShape.hasRank()) {
1112+
if (!operandShape.hasRank() || operandShape.getRank() == 0) {
11131113
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
11141114
return success();
11151115
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,15 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
591591
%1 = tosa.abs %0 : (tensor<?x1xf32>) -> tensor<?x1xf32>
592592
return %1 : tensor<?x1xf32>
593593
}
594+
595+
// -----
596+
597+
// CHECK-LABEL: @fold_reduce_rank_zero
598+
func.func nested @fold_reduce_rank_zero() {
599+
// CHECK-NOT: tosa.reduce_min
600+
// CHECK-NOT: tosa.reverse
601+
%0 = tensor.empty() : tensor<i32>
602+
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
603+
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
604+
return
605+
}

0 commit comments

Comments
 (0)