Skip to content

[mlir][tosa] Move cond_if and while_loop operations to controlflow extension #128216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025

Conversation

Tai78641
Copy link
Contributor

This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check conf_if and while_loop are supported only in the presence of the controlflow extension.

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check conf_if and while_loop are supported only in the presence of the controlflow extension.


Full diff: https://github.com/llvm/llvm-project/pull/128216.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+10-9)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+35-1)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 13bbba2b492fa..95c3b5c7c983d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,19 +241,20 @@ def Tosa_PRO_INT   : I32EnumAttrCase<"pro_int", 1>;
 def Tosa_PRO_FP   : I32EnumAttrCase<"pro_fp", 2>;
 def Tosa_NONE : I32EnumAttrCase<"none", 3>;
 
-def Tosa_EXT_INT16    : I32EnumAttrCase<"int16", 1>;
-def Tosa_EXT_INT4     : I32EnumAttrCase<"int4", 2>;
-def Tosa_EXT_BF16     : I32EnumAttrCase<"bf16", 3>;
-def Tosa_EXT_FP8E4M3  : I32EnumAttrCase<"fp8e4m3", 4>;
-def Tosa_EXT_FP8E5M2  : I32EnumAttrCase<"fp8e5m2", 5>;
-def Tosa_EXT_FFT      : I32EnumAttrCase<"fft", 6>;
-def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
-def Tosa_EXT_NONE     : I32EnumAttrCase<"none", 8>;
+def Tosa_EXT_INT16        : I32EnumAttrCase<"int16", 1>;
+def Tosa_EXT_INT4         : I32EnumAttrCase<"int4", 2>;
+def Tosa_EXT_BF16         : I32EnumAttrCase<"bf16", 3>;
+def Tosa_EXT_FP8E4M3      : I32EnumAttrCase<"fp8e4m3", 4>;
+def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
+def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
+def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
+def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_NONE         : I32EnumAttrCase<"none", 9>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 69a408767b3c6..7839548f4273e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2390,8 +2390,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
@@ -2431,8 +2431,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
+    Profile<[]>,
+    Extension<[Tosa_EXT_CONTROLFLOW]>,
   ];
 
   let regions = (region
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index a831bae12f3c1..57a4cd6a382ee 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -141,6 +141,7 @@ class TosaProfileCompliance {
     case Extension::fft:
       return {Profile::pro_fp};
     case Extension::variable:
+    case Extension::controlflow:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..32648830bb760 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, and variable\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
           return signalPassFailure();
         }
       }
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e66ff4cacfd89..da8f9ef82c839 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
 // -----
 // CHECK-LABEL: cond_if
 func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: tosa.cond_if profiles: [ ]
+  // CHECK: tosa.cond_if extensions: [ [controlflow] ]
   %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
     %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
     tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
 // CHECK-LABEL: while_loop
 func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: profiles: [ [pro_int, pro_fp] ]
-  // CHECK: extensions: [ [bf16] ]
+  // CHECK: profiles: [ ]
+  // CHECK: extensions: [ [controlflow] ]
   %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
     %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
     %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1aa8547cb2fdb..c44a0d1c09215 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 func.func @test_const() -> tensor<1xf32> {
   // expected-error@+1{{'tosa.const' op expected same attr/result element types}}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 046b9d5615074..684875f231dec 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles to focus the verification of expected extension requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
 
 // -----
 func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
   return %0 : tensor<13x21x3xi32>
 }
 
+// -----
+func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
+  %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
+    %2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  } do {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
+    %4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
+    %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
+  }
+  return
+}
+
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 90c4551564d1e..a75a6bee8e809 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -2,7 +2,7 @@
 // Enable all supported profiles and extensions to focus the verification of expected level errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
 
 func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 6dddcf329d110..8183b58272e84 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index c46b2543fbed5..f7cbd114280dc 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 479b7569f54ae..1d6d33b9a02c7 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {

@Tai78641 Tai78641 changed the title [TOSA] Move cond_if and while_loop operations to controlflow extension [mlir][tosa] Move cond_if and while_loop operations to controlflow extension Feb 21, 2025
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 9>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Does it makes sense to move it to id 0?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you referring to Tosa_EXT_NONE? Yes, it makes sense to move this to id 0 to avoid assigning zero to any valid extension,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to id 0

This commit adds the concept of a controlflow extension to the dialect
and updates the validation pass to check conf_if and while_loop are
supported only in the presence of the controlflow extension.

Change-Id: Ia2304baebd372d85f7e4f31e82d94ab85679e660
Signed-off-by: Luke Hutton <[email protected]>
@Jerry-Ge Jerry-Ge merged commit e58f475 into llvm:main Feb 25, 2025
11 checks passed
Jerry-Ge pushed a commit to Jerry-Ge/llvm-project that referenced this pull request Feb 28, 2025
A previous patch(llvm#128216) that added the support for the control flow
extension overlooked adding a comment. This patch adds the comment.

Change-Id: I5f922e23c6eecdef70968d5ee3e986446dd147de
Signed-off-by: Luke Hutton <[email protected]>
lhutton1 added a commit that referenced this pull request Mar 1, 2025
A previous patch(#128216) that added the support for the control flow
extension overlooked adding a comment. This patch adds the comment.

Signed-off-by: Luke Hutton <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
A previous patch(llvm#128216) that added the support for the control flow
extension overlooked adding a comment. This patch adds the comment.

Signed-off-by: Luke Hutton <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants