Skip to content

[mlir][Symbol] Add verification that symbol's parent is a SymbolTable #80590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 6, 2024

Conversation

caojoshua
Copy link
Contributor

Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table.

I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message.

@llvmbot
Copy link
Member

llvmbot commented Feb 4, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Joshua Cao (caojoshua)

Changes

Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table.

I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message.


Full diff: https://github.com/llvm/llvm-project/pull/80590.diff

12 Files Affected:

  • (modified) mlir/include/mlir/IR/SymbolInterfaces.td (+5)
  • (modified) mlir/test/Dialect/LLVMIR/global.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-replace.mlir (+4-2)
  • (modified) mlir/test/Dialect/Transform/ops-invalid.mlir (+1-2)
  • (modified) mlir/test/IR/invalid-func-op.mlir (+2-2)
  • (modified) mlir/test/IR/region.mlir (+3-4)
  • (modified) mlir/test/IR/traits.mlir (+16-17)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+7-7)
  • (modified) mlir/test/Transforms/canonicalize.mlir (+6-7)
  • (modified) mlir/test/Transforms/constant-fold.mlir (+7-4)
  • (modified) mlir/test/Transforms/cse.mlir (+7-4)
  • (modified) mlir/test/Transforms/test-legalizer-full.mlir (+5-3)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c..0bd5de9f18920e 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     if (concreteOp.isDeclaration() && concreteOp.isPublic())
       return concreteOp.emitOpError("symbol declaration cannot have public "
              "visibility");
+    auto parent = $_op->getParentOp();
+    if (parent && !parent->hasTrait<OpTrait::SymbolTable>()) {
+      return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
+             "trait");
+    }
     return success();
   }];
 
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 0649e814bfdfc0..3fa7636d4dd686 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
 // -----
 
 func.func @foo() {
-  // expected-error @+1 {{must appear at the module level}}
+  // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}}
   llvm.mlir.global internal @bar(42) : i32
 
   return
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
index 2801522e81ac2c..1a40912977dec2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.structured.replace %0 {
-      func.func @foo() {
-        "dummy_op"() : () -> ()
+      builtin.module {
+        func.func @foo() {
+          "dummy_op"() : () -> ()
+        }
       }
     } : (!transform.any_op) -> !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index e3f5bcf403f2ad..73a5f36af92952 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -433,10 +433,9 @@ module {
 // -----
 
 module attributes { transform.with_named_sequence} {
-  // expected-note @below {{ancestor transform op}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
-    // expected-error @below {{cannot be defined inside another transform op}}
+    // expected-error @below {{op symbol's parent must have the SymbolTable trai}}
     transform.named_sequence @nested() {
       transform.yield
     }
diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index d995689ebb8d0b..8fd7af22e9598b 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -31,7 +31,7 @@ func.func @func_op() {
 // -----
 
 func.func @func_op() {
-  // expected-error@+1 {{entry block must have 1 arguments to match function signature}}
+  // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
   func.func @mixed_named_arguments(f32) {
   ^entry:
     return
@@ -42,7 +42,7 @@ func.func @func_op() {
 // -----
 
 func.func @func_op() {
-  // expected-error@+1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}}
+  // expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
   func.func @mixed_named_arguments(f32) {
   ^entry(%arg : i32):
     return
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index bf4b1bb4e5ab1d..0b959915d6bbbe 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() {
 // CHECK: test.single_no_terminator_op
 "test.single_no_terminator_op"() (
   {
-    func.func @foo1() { return }
-    func.func @foo2() { return }
+    %foo = arith.constant 1 : i32
   }
 ) : () -> ()
 
 // CHECK: test.variadic_no_terminator_op
 "test.variadic_no_terminator_op"() (
   {
-    func.func @foo1() { return }
+    %foo = arith.constant 1 : i32
   },
   {
-    func.func @foo2() { return }
+    %bar = arith.constant 1 : i32
   }
 ) : () -> ()
 
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe7587508..1e046706379cdb 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
 
 // Ensure that SSACFG regions of operations in GRAPH regions are
 // checked for dominance
-func.func @illegalInsideDominanceFreeScope() -> () {
+func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () {
   test.graph_region {
-    func.func @test() -> i1 {
-    ^bb1:
+    scf.if %cond {
       // expected-error @+1 {{operand #0 does not dominate this use}}
       %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
       // expected-note @+1 {{operand defined here}}
-	   %1 = "baz"(%2#0) : (i1) -> (i64)
-      return %2#1 : i1
+      %1 = "baz"(%2#0) : (i1) -> (i64)
     }
     "terminator"() : () -> ()
   }
@@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () {
 
 // Ensure that SSACFG regions of operations in GRAPH regions are
 // checked for dominance
-func.func @illegalCDFGInsideDominanceFreeScope() -> () {
+func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () {
   test.graph_region {
-    func.func @test() -> i1 {
-    ^bb1:
-      // expected-error @+1 {{operand #0 does not dominate this use}}
-      %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
-      cf.br ^bb4
-    ^bb2:
-      cf.br ^bb2
-    ^bb4:
-      %1 = "foo"() : ()->i64   // expected-note {{operand defined here}}
-		return %2#1 : i1
+    scf.if %cond {
+      "test.ssacfg_region"() ({
+      ^bb1:
+        // expected-error @+1 {{operand #0 does not dominate this use}}
+        %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
+        cf.br ^bb4
+      ^bb2:
+        cf.br ^bb2
+      ^bb4:
+        %1 = "foo"() : ()->i64   // expected-note {{operand defined here}}
+      }) : () -> ()
     }
-     "terminator"() : () -> ()
+    "terminator"() : () -> ()
   }
   return
 }
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 46545d2e9fd510..3048a7fed636b5 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) {
 
 // Test case: Recursively DCE into enclosed regions.
 
-// CHECK:      func @f(%arg0: f32)
-// CHECK-NEXT:   func @g(%arg1: f32)
-// CHECK-NEXT:     return
+// CHECK:      func.func @f(%arg0: f32)
+// CHECK-NOT:     arith.addf
 
 func.func @f(%arg0: f32) {
-  func.func @g(%arg1: f32) {
-    %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
-    return
-  }
+  "test.region"() (
+    {
+      %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32
+    }
+  ) : () -> ()
   return
 }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 9b578e6c2631a7..2cf86b50d432f6 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -424,16 +424,15 @@ func.func @write_only_alloca_fold(%v: f32) {
 // CHECK-LABEL: func @dead_block_elim
 func.func @dead_block_elim() {
   // CHECK-NOT: ^bb
-  func.func @nested() {
-    return
+  builtin.module {
+    func.func @nested() {
+      return
 
-  ^bb1:
-    return
+    ^bb1:
+      return
+    }
   }
   return
-
-^bb1:
-  return
 }
 
 // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 45ee03fa31d25f..253163f2af9110 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 
 // CHECK-LABEL: func @nested_isolated_region
 func.func @nested_isolated_region() {
+  // CHECK-NEXT: builtin.module {
   // CHECK-NEXT: func @isolated_op
   // CHECK-NEXT: arith.constant 2
-  func.func @isolated_op() {
-    %0 = arith.constant 1 : i32
-    %2 = arith.addi %0, %0 : i32
-    "foo.yield"(%2) : (i32) -> ()
+  builtin.module {
+    func.func @isolated_op() {
+      %0 = arith.constant 1 : i32
+      %2 = arith.addi %0, %0 : i32
+      "foo.yield"(%2) : (i32) -> ()
+    }
   }
 
   // CHECK: "foo.unknown_region"
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8..11a33102684733 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 {
   // CHECK-NEXT: arith.constant 1
   %0 = arith.constant 1 : i32
 
+  // CHECK-NEXT: builtin.module
   // CHECK-NEXT: @nested_func
-  func.func @nested_func() {
-    // CHECK-NEXT: arith.constant 1
-    %foo = arith.constant 1 : i32
-    "foo.yield"(%foo) : (i32) -> ()
+  builtin.module {
+    func.func @nested_func() {
+      // CHECK-NEXT: arith.constant 1
+      %foo = arith.constant 1 : i32
+      "foo.yield"(%foo) : (i32) -> ()
+    }
   }
 
   // CHECK: "foo.region"
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d4..4268f18e611c0a 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() {
   }
   /// Operation that is dynamically legal, i.e. the function has a pattern
   /// applied to legalize the argument type before it becomes recursively legal.
-  func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
-    %ignored = "test.illegal_op_f"() : () -> (i32)
-    "test.return"() : () -> ()
+  builtin.module {
+    func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
+      %ignored = "test.illegal_op_f"() : () -> (i32)
+      "test.return"() : () -> ()
+    }
   }
 
   "test.return"() : () -> ()

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a nit.

@caojoshua
Copy link
Contributor Author

I removed the nested function in values.py, which was added in https://reviews.llvm.org/D149902. The test just looks for tests get_name(), and we don't need nested funcs to test that.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the nested function in values.py, which was added in https://reviews.llvm.org/D149902. The test just looks for tests get_name(), and we don't need nested funcs to test that.

Makes sense.

@caojoshua caojoshua merged commit 7d055af into llvm:main Feb 6, 2024
fifield added a commit to fifield/mlir-air that referenced this pull request Feb 13, 2024
fifield added a commit to fifield/mlir-air that referenced this pull request Feb 14, 2024
fifield added a commit to Xilinx/mlir-air that referenced this pull request Feb 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants