-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Require operand/result tensors of at least rank 1 for some operations #131335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…me operations This commit updates the following operations (operands/results) to be of at least rank 1 such that it aligns wih the expectations of the specification: - ARGMAX (input) - REDUCE_ALL (input/output) - REDUCE_ANY (input/output) - REDUCE_MAX (input/output) - REDUCE_MIN (input/output) - REDUCE_PRODUCT (input/output) - REDUCE_SUM (input/output) - CONCAT (each input in input1/output) - PAD (input1/output) - REVERSE (input1/output) - SLICE (input1/output) - TILE (input1/output) - TRANSPOSE (input1/output) In addition to this change, PAD has been updated to allow unranked tensors for input1/output, inline with other operations. Change-Id: I703c398e91c0a68366bab8e0778eb01a80588ebe
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThis commit updates the following operations (operands/results) to be of at least rank 1 such that it aligns wih the expectations of the specification:
In addition to this change, PAD has been updated to allow unranked tensors for input1/output, inline with other operations. Full diff: https://github.com/llvm/llvm-project/pull/131335.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b79993f48b379..0c99dd6130c2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
}];
let arguments = (ins
- Tosa_Tensor: $input,
+ Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$input1,
+ Variadic<Tosa_TensorAtLeast1D>:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
}];
let arguments = (ins
- Tosa_RankedTensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$padding,
Tosa_ScalarTensor:$pad_const
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$start,
Tosa_Shape:$size
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$multiples);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
DenseI32ArrayAttr:$perms
);
let results = (
- outs Tosa_Tensor:$output
+ outs Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 0038d8c386ca7..67011f22fbe2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+def AtLeastRankOne : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>;
+
class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_TensorAtLeast1D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 97a3009a20302..cdba332792eb0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() {
}
}
- RankedTensorType inputType = getInput1().getType();
- RankedTensorType outputType = getOutput().getType();
+ RankedTensorType inputType =
+ llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ RankedTensorType outputType =
+ llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+ if (!inputType || !outputType)
+ return success();
+
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
if (inputType.getRank() != outputType.getRank())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3bc438e465e1d..077a6cee0a1bb 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// -----
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func @fold_reduce_rank_zero() {
- // CHECK-NOT: tosa.reduce_min
- // CHECK-NOT: tosa.reverse
- %0 = tensor.empty() : tensor<i32>
- %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() -> tensor<i32> {
- // CHECK-NOT: tosa.tile
- %0 = tensor.empty() : tensor<i32>
- %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
- %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
- return %1 : tensor<i32>
-}
-
-// -----
-
// CHECK-LABEL: @reshape_quant_nofold
// check that segfault is fixed
func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
@@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.07574046018999
// -----
// CHECK-LABEL: @reverse_quant_fold
-func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
- // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+ // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
// CHECK: return %[[CST]]
- %0 = "tosa.const"() {values = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ return %1 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..2dc749422c12d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// -----
-func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () {
// expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
- %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32>
return
}
@@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
(tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}
+
+// -----
+
+func.func @test_scalar_argmax(%arg0: tensor<i32>) -> tensor<i32> {
+ // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+ %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+
+func.func @test_scalar_reduce_all(%arg0: tensor<i1>) -> tensor<i1> {
+ // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+ return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+ // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+ return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+ return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+ %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+ // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit updates the following operations (operands/results) to be of at least rank 1 such that it aligns with the expectations of the specification:
In addition to this change, PAD has been updated to allow unranked tensors for input1/output, inline with other operations.