Skip to content

Reformat whitespace in dependent dialects codegen #78090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2024

Conversation

mlevesquedion
Copy link
Contributor

The generated code for dependent dialects is awkwardly formatted, making the code harder to read. This change reformats the whitespace to align code in its context and avoid unnecessary empty lines.

Also included are some typo fixes.

Below are examples of the codegen for a dialect before and after the change.

Before:

GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {

    getContext()->loadDialect<arith::ArithDialect>();

  initialize();
}

After:

GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
  getContext()->loadDialect<arith::ArithDialect>();
  initialize();
}

Below are examples of the codegen for a pass before and after the change.

Before:

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

  registry.insert<func::FuncDialect>();

  registry.insert<tensor::TensorDialect>();

  registry.insert<tosa::TosaDialect>();

  }

After:

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<tosa::TosaDialect>();
  }

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 14, 2024
@mlevesquedion
Copy link
Contributor Author

I care about the formatting because I spend a fair bit of time reading generated code, and in general I've found the generated code to be quite readable.

I'm not sure what the best way to test this is, or even if this needs tests. I added some FileCheck tests, but these don't fully cover the formatting, and they may be brittle.

(This is my first time contributing.)

@llvmbot
Copy link
Member

llvmbot commented Jan 14, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: mlevesquedion (mlevesquedion)

Changes

The generated code for dependent dialects is awkwardly formatted, making the code harder to read. This change reformats the whitespace to align code in its context and avoid unnecessary empty lines.

Also included are some typo fixes.

Below are examples of the codegen for a dialect before and after the change.

Before:

GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get&lt;GPUDialect&gt;()) {

    getContext()-&gt;loadDialect&lt;arith::ArithDialect&gt;();

  initialize();
}

After:

GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get&lt;GPUDialect&gt;()) {
  getContext()-&gt;loadDialect&lt;arith::ArithDialect&gt;();
  initialize();
}

Below are examples of the codegen for a pass before and after the change.

Before:

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &amp;registry) const override {

  registry.insert&lt;func::FuncDialect&gt;();

  registry.insert&lt;tensor::TensorDialect&gt;();

  registry.insert&lt;tosa::TosaDialect&gt;();

  }

After:

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &amp;registry) const override {
    registry.insert&lt;func::FuncDialect&gt;();
    registry.insert&lt;tensor::TensorDialect&gt;();
    registry.insert&lt;tosa::TosaDialect&gt;();
  }

Full diff: https://github.com/llvm/llvm-project/pull/78090.diff

4 Files Affected:

  • (added) mlir/test/mlir-tblgen/dialect-with-dependents.td (+15)
  • (added) mlir/test/mlir-tblgen/pass.td (+22)
  • (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+11-9)
  • (modified) mlir/tools/mlir-tblgen/PassGen.cpp (+16-11)
diff --git a/mlir/test/mlir-tblgen/dialect-with-dependents.td b/mlir/test/mlir-tblgen/dialect-with-dependents.td
new file mode 100644
index 00000000000000..e915e13841b5e5
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-with-dependents.td
@@ -0,0 +1,15 @@
+// RUN: mlir-tblgen -gen-dialect-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def FooDialect : Dialect {
+  let name = "foo";
+  let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: FooDialect::FooDialect
+// CHECK: {
+// CHECK-NEXT: getContext()->loadDialect<func::FuncDialect>();
+// CHECK-NEXT: getContext()->loadDialect<shape::ShapeDialect>();
+// CHECK-NEXT: initialize();
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-tblgen/pass.td b/mlir/test/mlir-tblgen/pass.td
new file mode 100644
index 00000000000000..fb5580a2e3dd16
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pass.td
@@ -0,0 +1,22 @@
+// RUN: mlir-tblgen -gen-pass-decls -I %S/../../include %s | FileCheck %s
+
+include "mlir/Pass/PassBase.td"
+
+def FooPass : Pass<"foo", "ModuleOp"> {
+  let summary = "A pass for testing pass code generation.";
+  let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: GEN_PASS_DEF_FOO
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
+
+// CHECK-LABEL: GEN_PASS_CLASSES
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..5d4c294bcd8004 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {
 
 /// Registration for a single dependent dialect: to be inserted in the ctor
 /// above for each dependent dialect.
-const char *const dialectRegistrationTemplate = R"(
-    getContext()->loadDialect<{0}>();
-)";
+const char *const dialectRegistrationTemplate =
+    "getContext()->loadDialect<{0}>();";
 
 /// The code block for the attribute parser/printer hooks.
 static const char *const attrParserDecl = R"(
@@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// The code block to generate a dialect constructor definition.
 ///
 /// {0}: The name of the dialect class.
-/// {1}: initialization code that is emitted in the ctor body before calling
-///      initialize().
+/// {1}: Initialization code that is emitted in the ctor body before calling
+///      initialize(), such as dependent dialect registration.
 /// {2}: The dialect parent class.
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context)
@@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
 }
 )";
 
-/// The code block to generate a default desturctor definition.
+/// The code block to generate a default destructor definition.
 ///
 /// {0}: The name of the dialect class.
 static const char *const dialectDestructorStr = R"(
@@ -284,9 +283,12 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
-    for (StringRef dependentDialect : dialect.getDependentDialects())
-      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
-                                  dependentDialect);
+    llvm::interleave(
+        dialect.getDependentDialects(), dialectsOs,
+        [&dialectsOs](StringRef dd) {
+          dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+        },
+        "\n  ");
   }
 
   // Emit the constructor and destructor.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index de159d144ffbb4..4c895266177429 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
 /// {0}: The def name of the pass record.
 /// {1}: The base class for the pass.
 /// {2): The command line argument for the pass.
-/// {3}: The dependent dialects registration.
+/// {3}: The summary for the pass.
+/// {4}: The dependent dialects registration.
 const char *const baseClassBegin = R"(
 template <typename DerivedT>
 class {0}Base : public {1} {
@@ -221,9 +222,7 @@ class {0}Base : public {1} {
 
 /// Registration for a single dependent dialect, to be inserted for each
 /// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = R"(
-  registry.insert<{0}>();
-)";
+const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
 
 const char *const friendDefaultConstructorDeclTemplate = R"(
 namespace impl {{
@@ -307,9 +306,12 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
-    for (StringRef dependentDialect : pass.getDependentDialects())
-      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
-                                  dependentDialect);
+    llvm::interleave(
+        pass.getDependentDialects(), dialectsOs,
+        [&dialectsOs](StringRef dd) {
+          dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+        },
+        "\n    ");
   }
 
   os << "namespace impl {\n";
@@ -402,7 +404,7 @@ class {0}Base : public {1} {
     return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
   }
 
-  /// Return the dialect that must be loaded in the context before this pass.
+  /// Register the dialects that must be loaded in the context before this pass.
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
     {4}
   }
@@ -422,9 +424,12 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
-    for (StringRef dependentDialect : pass.getDependentDialects())
-      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
-                                  dependentDialect);
+    llvm::interleave(
+        pass.getDependentDialects(), dialectsOs,
+        [&dialectsOs](StringRef dd) {
+          dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+        },
+        "\n    ");
   }
   os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
                       pass.getArgument(), pass.getSummary(),

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Big fan of this change as an avid reader of TableGen generated code 🙂

One thing I am worried about are the tests. I don't think they provide a lot of benefits: They effectively only test whether the insert calls are one line after the other with no whitespace lines before or after and do not test indendation.
At the same time they are probably prone to breakage if any future changes are done to the codegen of these methods.

I'd personally just remove the tests.

@mlevesquedion
Copy link
Contributor Author

Thank you for taking the time to review! I removed the tests and resolved your other comment.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know this code path as well, but change seems ok to me.
But overall, maybe we should just call clang-format on the generated code if we can directly invoke the code from c++.

@mlevesquedion
Copy link
Contributor Author

@MaheshRavishankar Good point. Another approach would be to generate the code as data structures and pretty print them. I don't know if either of these approaches is practical though. I think the change is valuable as it is, but I agree there would be room to look for a more complete solution.

@joker-eph Thank you for taking the time to review! I don't have write access to the repo, feel free to submit the change.

Michael Levesque-Dion added 2 commits January 14, 2024 21:12
The generated code for dependent dialects is awkwardly formatted, making
the code harder to read. This change reformats the whitespace to align
code in its context and avoid unnecessary empty lines.

Also included are some typo fixes.

Below are examples of the codegen for a dialect before and after the
change.

Before:

```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {

    getContext()->loadDialect<arith::ArithDialect>();

  initialize();
}
```

After:

```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
  getContext()->loadDialect<arith::ArithDialect>();
  initialize();
}
```

Below are examples of the codegen for a pass before and after the change.

Before:

```
  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

  registry.insert<func::FuncDialect>();

  registry.insert<tensor::TensorDialect>();

  registry.insert<tosa::TosaDialect>();

  }
```

After:

```
  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<tosa::TosaDialect>();
  }
```
@mlevesquedion mlevesquedion force-pushed the tidy-up-dependent-dialects-codegen branch from 5b2e1a4 to 5c42ffc Compare January 15, 2024 05:12
@zero9178 zero9178 merged commit 3dff20c into llvm:main Jan 15, 2024
@mlevesquedion mlevesquedion deleted the tidy-up-dependent-dialects-codegen branch January 15, 2024 21:07
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
The generated code for dependent dialects is awkwardly formatted, making
the code harder to read. This change reformats the whitespace to align
code in its context and avoid unnecessary empty lines.

Also included are some typo fixes.

Below are examples of the codegen for a dialect before and after the
change.

Before:

```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {

    getContext()->loadDialect<arith::ArithDialect>();

  initialize();
}
```

After:

```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
  getContext()->loadDialect<arith::ArithDialect>();
  initialize();
}
```

Below are examples of the codegen for a pass before and after the
change.

Before:

```
  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

  registry.insert<func::FuncDialect>();

  registry.insert<tensor::TensorDialect>();

  registry.insert<tosa::TosaDialect>();

  }
```

After:

```
  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<tosa::TosaDialect>();
  }
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants