Skip to content

Commit 5418e38

Browse files
lhutton1Tai78641
authored andcommitted
[TOSA] Move cond_if and while_loop operations to controlflow extension
This commit adds the concept of a controlflow extension to the dialect and updates the validation pass to check conf_if and while_loop are supported only in the presence of the controlflow extension. Change-Id: Ia2304baebd372d85f7e4f31e82d94ab85679e660 Signed-off-by: Luke Hutton <[email protected]>
1 parent 9fac59a commit 5418e38

File tree

11 files changed

+61
-25
lines changed

11 files changed

+61
-25
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,23 +237,24 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
237237
// VARIABLE : Stateful variable operations.
238238
//===----------------------------------------------------------------------===//
239239

240+
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
240241
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
241242
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
242-
def Tosa_NONE : I32EnumAttrCase<"none", 3>;
243243

244-
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
245-
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
246-
def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
247-
def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
248-
def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
249-
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
250-
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
251-
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 8>;
244+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
245+
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
246+
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
247+
def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>;
248+
def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>;
249+
def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
250+
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
251+
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
252+
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
252253

253254
def Tosa_ExtensionAttr
254255
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
255256
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
256-
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE
257+
Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
257258
]>;
258259

259260
def Tosa_ExtensionArrayAttr

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,8 +2436,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
24362436
);
24372437

24382438
list<Availability> availability = [
2439-
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
2440-
Extension<[]>,
2439+
Profile<[]>,
2440+
Extension<[Tosa_EXT_CONTROLFLOW]>,
24412441
];
24422442

24432443
let regions = (region
@@ -2477,8 +2477,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
24772477
);
24782478

24792479
list<Availability> availability = [
2480-
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
2481-
Extension<[]>,
2480+
Profile<[]>,
2481+
Extension<[Tosa_EXT_CONTROLFLOW]>,
24822482
];
24832483

24842484
let regions = (region

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class TosaProfileCompliance {
143143
case Extension::fft:
144144
return {Profile::pro_fp};
145145
case Extension::variable:
146+
case Extension::controlflow:
146147
return {Profile::pro_fp, Profile::pro_int};
147148
case Extension::none:
148149
return {};

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
425425
} else {
426426
llvm::errs() << "unknown TOSA extension name passed in: " << ext
427427
<< ", supported extension are int16, int4, bf16, "
428-
<< "fp8e4m3, fp8e5m2, fft, and variable\n";
428+
<< "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
429429
return signalPassFailure();
430430
}
431431
}

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
629629
// -----
630630
// CHECK-LABEL: cond_if
631631
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
632-
// CHECK: profiles: [ [pro_int, pro_fp] ]
633-
// CHECK: extensions: [ [bf16] ]
632+
// CHECK: tosa.cond_if profiles: [ ]
633+
// CHECK: tosa.cond_if extensions: [ [controlflow] ]
634634
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
635635
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
636636
tosa.yield %1 : tensor<f32>
@@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1
645645
// CHECK-LABEL: while_loop
646646
func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
647647
%0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
648-
// CHECK: profiles: [ [pro_int, pro_fp] ]
649-
// CHECK: extensions: [ [bf16] ]
648+
// CHECK: profiles: [ ]
649+
// CHECK: extensions: [ [controlflow] ]
650650
%1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
651651
%2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
652652
%3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// validation flow.
55
//--------------------------------------------------------------------------------------------------
66

7-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
7+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
88

99
func.func @test_const() -> tensor<1xf32> {
1010
// expected-error@+1{{'tosa.const' op expected same attr/result element types}}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Enable all supported profiles to focus the verification of expected extension requirement errors.
33
//--------------------------------------------------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
66

77
// -----
88
func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
@@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32
3636
return %0 : tensor<13x21x3xi32>
3737
}
3838

39+
// -----
40+
func.func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
41+
// expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}}
42+
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
43+
%1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
44+
tosa.yield %1 : tensor<f32>
45+
} else {
46+
%1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
47+
tosa.yield %1 : tensor<f32>
48+
}
49+
return %0 : tensor<f32>
50+
}
51+
52+
// -----
53+
func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
54+
%0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
55+
// expected-error@+1 {{'tosa.while_loop' op illegal: requires [controlflow]}}
56+
%1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) {
57+
%2 = tosa.greater_equal %arg3, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
58+
%3 = tosa.logical_not %2 : (tensor<i1>) -> tensor<i1>
59+
tosa.yield %3 : tensor<i1>
60+
} do {
61+
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):
62+
%2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
63+
%3 = tosa.add %arg3, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
64+
%7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
65+
%4 = tosa.reshape %2, %7 : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
66+
%5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
67+
%6 = tosa.add %arg2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i32>
68+
tosa.yield %6, %3, %5 : tensor<i32>, tensor<i32>, tensor<10xi32>
69+
}
70+
return
71+
}
72+

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Enable all supported profiles and extensions to focus the verification of expected level errors.
33
//--------------------------------------------------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow"
66

77
func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
88
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}

mlir/test/Dialect/Tosa/profile_all_unsupported.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Enable all supported extensions to focus the verification of expected profile requirement errors.
33
//--------------------------------------------------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
66

77
// -----
88
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {

mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Enable all supported extensions to focus the verification of expected profile requirement errors.
33
//--------------------------------------------------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
66

77
// -----
88
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {

mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Enable all supported extensions to focus the verification of expected profile requirement errors.
33
//--------------------------------------------------------------------------------------------------
44

5-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment"
5+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
66

77
// -----
88
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {

0 commit comments

Comments
 (0)