[mlir][tosa] Require operand/result tensors of at least rank 1 for some operations #131335

Merged
merged 1 commit into from
Mar 17, 2025

Conversation

Contributor

@lhutton1 lhutton1 commented Mar 14, 2025

This commit updates the following operations (operands/results) to require at least rank 1, so that they align with the expectations of the specification:

  • ARGMAX (input)
  • REDUCE_ALL (input/output)
  • REDUCE_ANY (input/output)
  • REDUCE_MAX (input/output)
  • REDUCE_MIN (input/output)
  • REDUCE_PRODUCT (input/output)
  • REDUCE_SUM (input/output)
  • CONCAT (each input in input1/output)
  • PAD (input1/output)
  • REVERSE (input1/output)
  • SLICE (input1/output)
  • TILE (input1/output)
  • TRANSPOSE (input1/output)

In addition to this change, PAD has been updated to allow unranked tensors for input1/output, in line with other operations.
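With this change, rank-0 (scalar) tensors are rejected for the listed operands/results, while rank >= 1 and unranked tensors remain valid. A minimal illustrative sketch of the before/after behaviour (the SSA value names and shapes here are assumed for illustration, not taken from this PR):

```mlir
// Now rejected: scalar operand.
// error: 'tosa.reduce_sum' op operand #0 must be tosa-conformant tensor of at least rank 1
%0 = tosa.reduce_sum %scalar {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>

// Still valid: rank-1 operand; reducing axis 0 of a 4-element vector
// yields a 1-element vector along that axis.
%1 = tosa.reduce_sum %vector {axis = 0 : i32} : (tensor<4xi32>) -> tensor<1xi32>
```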

…me operations

This commit updates the following operations (operands/results) to
be of at least rank 1 such that it aligns with the expectations
of the specification:
- ARGMAX (input)
- REDUCE_ALL (input/output)
- REDUCE_ANY (input/output)
- REDUCE_MAX (input/output)
- REDUCE_MIN (input/output)
- REDUCE_PRODUCT (input/output)
- REDUCE_SUM (input/output)
- CONCAT (each input in input1/output)
- PAD (input1/output)
- REVERSE (input1/output)
- SLICE (input1/output)
- TILE (input1/output)
- TRANSPOSE (input1/output)

In addition to this change, PAD has been updated to allow
unranked tensors for input1/output, in line with other operations.

Change-Id: I703c398e91c0a68366bab8e0778eb01a80588ebe
@llvmbot
Member

llvmbot commented Mar 14, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit updates the following operations (operands/results) to require at least rank 1, so that they align with the expectations of the specification:

  • ARGMAX (input)
  • REDUCE_ALL (input/output)
  • REDUCE_ANY (input/output)
  • REDUCE_MAX (input/output)
  • REDUCE_MIN (input/output)
  • REDUCE_PRODUCT (input/output)
  • REDUCE_SUM (input/output)
  • CONCAT (each input in input1/output)
  • PAD (input1/output)
  • REVERSE (input1/output)
  • SLICE (input1/output)
  • TILE (input1/output)
  • TRANSPOSE (input1/output)

In addition to this change, PAD has been updated to allow unranked tensors for input1/output, in line with other operations.


Full diff: https://github.com/llvm/llvm-project/pull/131335.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+25-25)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+7-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+5-28)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+71-2)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b79993f48b379..0c99dd6130c2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_TensorAtLeast1D>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$padding,
     Tosa_ScalarTensor:$pad_const
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     DenseI32ArrayAttr:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 0038d8c386ca7..67011f22fbe2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[
     IsRankedTensorTypePred,
     CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
 
+def AtLeastRankOne : And<[
+  IsRankedTensorTypePred,
+  CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>;
+
 class TosaTensorOf<
     list<Type> allowedTypes, string summary = "tosa-conformant tensor">
     : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
+def Tosa_TensorAtLeast1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 97a3009a20302..cdba332792eb0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() {
     }
   }
 
-  RankedTensorType inputType = getInput1().getType();
-  RankedTensorType outputType = getOutput().getType();
+  RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+  RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+  if (!inputType || !outputType)
+    return success();
+
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3bc438e465e1d..077a6cee0a1bb 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 
 // -----
 
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func @fold_reduce_rank_zero() {
-  // CHECK-NOT: tosa.reduce_min
-  // CHECK-NOT: tosa.reverse
-  %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() -> tensor<i32> {
-  // CHECK-NOT: tosa.tile
-  %0 = tensor.empty() : tensor<i32>
-  %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
-  %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
-  return %1 : tensor<i32>
-}
-
-// -----
-
 // CHECK-LABEL: @reshape_quant_nofold
 // check that segfault is fixed
 func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
@@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.07574046018999
 // -----
 
 // CHECK-LABEL: @reverse_quant_fold
-func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
-   // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    // CHECK: return %[[CST]]
-   %0 = "tosa.const"() {values = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %1 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..2dc749422c12d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
 
 // -----
 
-func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () {
   // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
-  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32>
   return
 }
 
@@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
          (tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
   return %0 : tensor<1x32x2x8xf32>
 }
+
+// -----
+
+func.func @test_scalar_argmax(%arg0: tensor<i32>) -> tensor<i32> {
+  // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+func.func @test_scalar_reduce_all(%arg0: tensor<i1>) -> tensor<i1> {
+  // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+  // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+  %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
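Because PAD now also admits unranked tensors, the updated `PadOp::verify` returns early when either the input or output type is unranked, instead of assuming `RankedTensorType` as before. A sketch of IR that should verify after this change (the function and value names are assumed for illustration):

```mlir
func.func @pad_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %pad_const = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  // Pad one element on each side of a (presumed) rank-1 input.
  %padding = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
  %0 = tosa.pad %arg0, %padding, %pad_const : (tensor<*xf32>, !tosa.shape<2>, tensor<1xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
```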

Contributor

@CoTinker CoTinker left a comment

LGTM

@GeorgeARM GeorgeARM merged commit 0c34d7a into llvm:main Mar 17, 2025
14 checks passed
@lhutton1 lhutton1 deleted the ops-require-rank-1 branch April 9, 2025 21:18