-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Symbol] Add verification that symbol's parent is a SymbolTable #80590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Joshua Cao (caojoshua) ChangesFollowing the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table. I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message. Full diff: https://github.com/llvm/llvm-project/pull/80590.diff 12 Files Affected:
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c..0bd5de9f18920e 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
if (concreteOp.isDeclaration() && concreteOp.isPublic())
return concreteOp.emitOpError("symbol declaration cannot have public "
"visibility");
+ auto parent = $_op->getParentOp();
+ if (parent && !parent->hasTrait<OpTrait::SymbolTable>()) {
+ return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
+ "trait");
+ }
return success();
}];
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 0649e814bfdfc0..3fa7636d4dd686 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
// -----
func.func @foo() {
- // expected-error @+1 {{must appear at the module level}}
+ // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}}
llvm.mlir.global internal @bar(42) : i32
return
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
index 2801522e81ac2c..1a40912977dec2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.replace %0 {
- func.func @foo() {
- "dummy_op"() : () -> ()
+ builtin.module {
+ func.func @foo() {
+ "dummy_op"() : () -> ()
+ }
}
} : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index e3f5bcf403f2ad..73a5f36af92952 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -433,10 +433,9 @@ module {
// -----
module attributes { transform.with_named_sequence} {
- // expected-note @below {{ancestor transform op}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
- // expected-error @below {{cannot be defined inside another transform op}}
+ // expected-error @below {{op symbol's parent must have the SymbolTable trai}}
transform.named_sequence @nested() {
transform.yield
}
diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index d995689ebb8d0b..8fd7af22e9598b 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -31,7 +31,7 @@ func.func @func_op() {
// -----
func.func @func_op() {
- // expected-error@+1 {{entry block must have 1 arguments to match function signature}}
+ // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
func.func @mixed_named_arguments(f32) {
^entry:
return
@@ -42,7 +42,7 @@ func.func @func_op() {
// -----
func.func @func_op() {
- // expected-error@+1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}}
+ // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
func.func @mixed_named_arguments(f32) {
^entry(%arg : i32):
return
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index bf4b1bb4e5ab1d..0b959915d6bbbe 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() {
// CHECK: test.single_no_terminator_op
"test.single_no_terminator_op"() (
{
- func.func @foo1() { return }
- func.func @foo2() { return }
+ %foo = arith.constant 1 : i32
}
) : () -> ()
// CHECK: test.variadic_no_terminator_op
"test.variadic_no_terminator_op"() (
{
- func.func @foo1() { return }
+ %foo = arith.constant 1 : i32
},
{
- func.func @foo2() { return }
+ %bar = arith.constant 1 : i32
}
) : () -> ()
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe7587508..1e046706379cdb 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
// Ensure that SSACFG regions of operations in GRAPH regions are
// checked for dominance
-func.func @illegalInsideDominanceFreeScope() -> () {
+func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () {
test.graph_region {
- func.func @test() -> i1 {
- ^bb1:
+ scf.if %cond {
// expected-error @+1 {{operand #0 does not dominate this use}}
%2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
// expected-note @+1 {{operand defined here}}
- %1 = "baz"(%2#0) : (i1) -> (i64)
- return %2#1 : i1
+ %1 = "baz"(%2#0) : (i1) -> (i64)
}
"terminator"() : () -> ()
}
@@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () {
// Ensure that SSACFG regions of operations in GRAPH regions are
// checked for dominance
-func.func @illegalCDFGInsideDominanceFreeScope() -> () {
+func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () {
test.graph_region {
- func.func @test() -> i1 {
- ^bb1:
- // expected-error @+1 {{operand #0 does not dominate this use}}
- %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
- cf.br ^bb4
- ^bb2:
- cf.br ^bb2
- ^bb4:
- %1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
- return %2#1 : i1
+ scf.if %cond {
+ "test.ssacfg_region"() ({
+ ^bb1:
+ // expected-error @+1 {{operand #0 does not dominate this use}}
+ %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
+ cf.br ^bb4
+ ^bb2:
+ cf.br ^bb2
+ ^bb4:
+ %1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
+ }) : () -> ()
}
- "terminator"() : () -> ()
+ "terminator"() : () -> ()
}
return
}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 46545d2e9fd510..3048a7fed636b5 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) {
// Test case: Recursively DCE into enclosed regions.
-// CHECK: func @f(%arg0: f32)
-// CHECK-NEXT: func @g(%arg1: f32)
-// CHECK-NEXT: return
+// CHECK: func.func @f(%arg0: f32)
+// CHECK-NOT: arith.addf
func.func @f(%arg0: f32) {
- func.func @g(%arg1: f32) {
- %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
- return
- }
+ "test.region"() (
+ {
+ %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32
+ }
+ ) : () -> ()
return
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 9b578e6c2631a7..2cf86b50d432f6 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -424,16 +424,15 @@ func.func @write_only_alloca_fold(%v: f32) {
// CHECK-LABEL: func @dead_block_elim
func.func @dead_block_elim() {
// CHECK-NOT: ^bb
- func.func @nested() {
- return
+ builtin.module {
+ func.func @nested() {
+ return
- ^bb1:
- return
+ ^bb1:
+ return
+ }
}
return
-
-^bb1:
- return
}
// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 45ee03fa31d25f..253163f2af9110 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
// CHECK-LABEL: func @nested_isolated_region
func.func @nested_isolated_region() {
+ // CHECK-NEXT: builtin.module {
// CHECK-NEXT: func @isolated_op
// CHECK-NEXT: arith.constant 2
- func.func @isolated_op() {
- %0 = arith.constant 1 : i32
- %2 = arith.addi %0, %0 : i32
- "foo.yield"(%2) : (i32) -> ()
+ builtin.module {
+ func.func @isolated_op() {
+ %0 = arith.constant 1 : i32
+ %2 = arith.addi %0, %0 : i32
+ "foo.yield"(%2) : (i32) -> ()
+ }
}
// CHECK: "foo.unknown_region"
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8..11a33102684733 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 {
// CHECK-NEXT: arith.constant 1
%0 = arith.constant 1 : i32
+ // CHECK-NEXT: builtin.module
// CHECK-NEXT: @nested_func
- func.func @nested_func() {
- // CHECK-NEXT: arith.constant 1
- %foo = arith.constant 1 : i32
- "foo.yield"(%foo) : (i32) -> ()
+ builtin.module {
+ func.func @nested_func() {
+ // CHECK-NEXT: arith.constant 1
+ %foo = arith.constant 1 : i32
+ "foo.yield"(%foo) : (i32) -> ()
+ }
}
// CHECK: "foo.region"
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d4..4268f18e611c0a 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() {
}
/// Operation that is dynamically legal, i.e. the function has a pattern
/// applied to legalize the argument type before it becomes recursively legal.
- func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
- %ignored = "test.illegal_op_f"() : () -> (i32)
- "test.return"() : () -> ()
+ builtin.module {
+ func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
+ %ignored = "test.illegal_op_f"() : () -> (i32)
+ "test.return"() : () -> ()
+ }
}
"test.return"() : () -> ()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a nit.
Co-authored-by: Mehdi Amini <[email protected]>
I removed the nested function in values.py, which was added in https://reviews.llvm.org/D149902. The test just looks for tests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed the nested function in values.py, which was added in https://reviews.llvm.org/D149902. The test just looks for tests
get_name()
, and we don't need nested funcs to test that.
Makes sense.
Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table.
I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message.