[mlir] Start rewrite tool #77668

Merged
merged 1 commit into from
Oct 12, 2024

jpienaar
Member

Initial commit of a tool to help with textual rewrites of .mlir files. This tool builds on AsmParserState and is rather simple. I took some inspiration from my use of clang's AST rewrites, where I'd often treat it as a "localizing" regex applicator in fallback cases, and started with that as the functionality. There, though, one has access to lower-level info than here, but this is still a step up over running sed over the entire file.

This aims to be a helpful tool (e.g., rewriting syntax, including best effort inside comments) rather than a bulletproof one. It may even be better suited under utils than tools. Most of the rewrites would be rather short-lived and might never make it upstream (while the helpers for those rewrites may, for use in future rewrites).

The layering at the moment is not ideal, as it reuses the RewriteBuffer class from clang's rewrite engine, so the tool is only optionally enabled where clang is also enabled. There doesn't seem to be anything clang-specific there (the dep does pull in more dependencies than ideal, but I am leaving both refactorings for later).

Additionally, I started it as a single file to prototype more easily, planning to refactor later into include and lib directories for out-of-file usage.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 10, 2024
@llvmbot
Member

llvmbot commented Jan 10, 2024

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/77668.diff

9 Files Affected:

  • (modified) mlir/CMakeLists.txt (+5)
  • (added) mlir/docs/Tools/mlir-rewrite.md (+29)
  • (modified) mlir/test/CMakeLists.txt (+7)
  • (modified) mlir/test/lit.cfg.py (+4)
  • (modified) mlir/test/lit.site.cfg.py.in (+1)
  • (added) mlir/test/mlir-rewrite/simple.mlir (+12)
  • (modified) mlir/tools/CMakeLists.txt (+5)
  • (added) mlir/tools/mlir-rewrite/CMakeLists.txt (+37)
  • (added) mlir/tools/mlir-rewrite/mlir-rewrite.cpp (+395)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 2d9f78e03ba76b..64aad84e90a5ad 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -285,3 +285,8 @@ endif()
 if(MLIR_STANDALONE_BUILD)
   llvm_distribution_add_targets()
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  set(MLIR_ENABLE_REWRITE ON CACHE BOOL "mlir-rewrite enabled")
+endif()
diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md
new file mode 100644
index 00000000000000..178f92f72cbb6e
--- /dev/null
+++ b/mlir/docs/Tools/mlir-rewrite.md
@@ -0,0 +1,29 @@
+# mlir-rewrite
+
+Tool to simplify rewriting .mlir files. There are a couple of built-in rewrites
+discussed below along with usage.
+
+Note: This is still at a very early stage. It's so early it's less a tool than
+a growing collection of useful functions: to use it, it's best to do what's
+needed on a branch by just hacking it (dialects registered, rewrites, etc.) to,
+say, help ease a rename, upstream useful utility functions, point to it to ease
+others migrating, and then bin it eventually. Once there are actually useful
+parts it should be refactored the same as mlir-opt.
+
+[TOC]
+
+## simple-rename
+
+Rename a substring to a target within the text of each matching op. The match
+and replace use LLVM's regex sub, while the op name is matched via plain
+string comparison. E.g.,
+
+```
+mlir-rewrite input.mlir -o output.mlir --simple-rename \
+   --simple-rename-op-name="test.concat" --simple-rename-match="axis" \
+                                         --simple-rename-replace="bxis"
+```
+
+to replace `axis` substring in the text of the range corresponding to
+`test.concat` ops with `bxis`.
+
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 8ce030feeded92..397a2efcf5e9f2 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -197,6 +197,13 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   )
 endif()
 
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-rewrite
+  )
+endif()
+
 # This target can be used to just build the dependencies
 # for the check-mlir target without executing the tests.
 # This is useful for bots when splitting the build step
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 0a1ea1d16da452..35d6a3bd1f5636 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -144,6 +144,10 @@ def add_runtime(name):
         )
     )
 
+if config.enable_mlir_rewrite:
+    tools.extend(["mlir-rewrite"])
+    config.available_features.add('mlir-rewrite')
+
 # The following tools are optional
 tools.extend(
     [
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index c0fa1b8980e539..d35e3701198a56 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -23,6 +23,7 @@ config.mlir_obj_root = "@MLIR_BINARY_DIR@"
 config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
 config.mlir_cmake_dir = "@MLIR_CMAKE_DIR@"
 config.mlir_lib_dir = "@MLIR_LIB_DIR@"
+config.enable_mlir_rewrite = "@MLIR_ENABLE_REWRITE@"
 
 config.build_examples = @LLVM_BUILD_EXAMPLES@
 config.run_cuda_tests = @MLIR_ENABLE_CUDA_CONVERSIONS@
diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir
new file mode 100644
index 00000000000000..cf3a029b0653b0
--- /dev/null
+++ b/mlir/test/mlir-rewrite/simple.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME
+// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE
+// Note: running through mlir-opt to just strip out comments & avoid self matches.
+// REQUIRES: mlir-rewrite
+
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64}
+  // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>》
+  %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  return %5 : tensor<?x4x?xf32>
+}
+
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 9b474385fdae18..7d330e124a2ca2 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -15,3 +15,8 @@ add_subdirectory(tblgen-to-irdl)
 if(MLIR_ENABLE_EXECUTION_ENGINE)
   add_subdirectory(mlir-cpu-runner)
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  add_subdirectory(mlir-rewrite)
+endif()
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
new file mode 100644
index 00000000000000..29126432d2de5d
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -0,0 +1,37 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+set(LIBS
+  ${dialect_libs}
+  ${test_libs}
+
+  clangRewrite
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIROptLib
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  MLIRIR
+  )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-rewrite
+  mlir-rewrite.cpp
+
+  DEPENDS
+  ${LIBS}
+  SUPPORT_PLUGINS
+  )
+target_link_libraries(mlir-rewrite PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-rewrite)
+
+mlir_check_all_link_libraries(mlir-rewrite)
+export_executable_symbols_for_plugins(mlir-rewrite)
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
new file mode 100644
index 00000000000000..0648dc4309ab6c
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -0,0 +1,395 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+// TODO: Refactor the RewriteBuffer out to avoid the weird Clang dep.
+#include "clang/Rewrite/Core/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace llvm;
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source code associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  for (auto res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+class RewriteBuffer {
+public:
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions of the parsed operations.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the range of the operation in the source text with the
+  /// corresponding string in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the source string corresponding to the source range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the source string corresponding to operation definition.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    auto range = getOpRange(opDef);
+    return getSourceString(range);
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return lines that are purely comments.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    auto lineIterator = line_iterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+
+  /// Return the IR from parsed file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition for the given operation. The operation must have a
+  /// definition: the pointer returned by asmState.getOpDef is dereferenced.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    return *asmState.getOpDef(op);
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The RewriteBuffer from clang-rewrite is doing most of the real work.
+  clang::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offset.
+  const char *start;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+                                                   StringRef outputFilename) {
+  std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  // Set up the MLIR context and error handling.
+  r->context.appendDialectRegistry(r->registry);
+
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = r->sourceMgr.getMainFileID();
+  const MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf);
+  r->start = curMB->getBufferStart();
+  r->rewriteBuffer.Initialize(curMB->getBuffer());
+
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
+                           &r->fallbackResourceMap);
+  // Always allow unregistered.
+  r->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig,
+                                &r->asmState)))
+    return nullptr;
+
+  return r;
+}
+
+/// Return the source code associated with the operation name.
+SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+
+/// Return whether the operation was printed using generic syntax in original
+/// buffer.
+bool isGeneric(const OperationDefinition &op) {
+  return op.loc.Start.getPointer()[0] == '"';
+}
+
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Rewriter function to invoke.
+using RewriterFunction = std::function<mlir::LogicalResult(
+    mlir::RewriteBuffer &rewriteBuffer, llvm::raw_ostream &os)>;
+
+/// Structure to group information about a rewriter (argument to invoke via
+/// mlir-rewrite, description, and rewriter function).
+class RewriterInfo {
+public:
+  /// RewriterInfo constructor should not be invoked directly, instead use
+  /// RewriterRegistration or registerRewriter.
+  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
+      : arg(arg), description(description), rewriter(std::move(rewriter)) {}
+
+  /// Invokes the rewriter and returns whether the rewriter failed.
+  LogicalResult invoke(mlir::RewriteBuffer &rewriteBuffer,
+                       raw_ostream &os) const {
+    assert(rewriter && "Cannot call rewriter with null rewriter");
+    return rewriter(rewriteBuffer, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-rewrite' to
+  /// invoke this rewriter.
+  StringRef getRewriterArgument() const { return arg; }
+
+  /// Returns a description for the rewriter.
+  StringRef getRewriterDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the rewriter via mlir-rewrite.
+  StringRef arg;
+
+  // Description of the rewriter.
+  StringRef description;
+
+  // Rewriter function.
+  RewriterFunction rewriter;
+};
+
+static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
+
+/// Adds command line option for each registered rewriter.
+struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
+  RewriterNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+};
+
+/// RewriterRegistration provides a global initializer that registers a rewriter
+/// function.
+struct RewriterRegistration {
+  RewriterRegistration(StringRef arg, StringRef description,
+                       const RewriterFunction &function);
+};
+
+RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
+                                           const RewriterFunction &function) {
+  rewriterRegistry->emplace_back(arg, description, function);
+}
+
+RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const RewriterInfo *>(opt) {
+  for (const auto &kv : *rewriterRegistry) {
+    addLiteralOption(kv.getRewriterArgument(), &kv,
+                     kv.getRewriterDescription());
+  }
+}
+
+void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
+                                         size_t globalWidth) const {
+  RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const RewriterNameParser::OptionInfo *vT1,
+                          const RewriterNameParser::OptionInfo *vT2) {
+                         return vT1->Name.compare(vT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
+}
+
+} // namespace mlir
+
+// TODO: Make these injectable too in non-global way.
+static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
+static llvm::cl::opt<std::string> simpleRenameOpName{
+    "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameMatch{
+    "simple-rename-match", llvm::cl::desc("Match string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameReplace{
+    "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+
+// Rewriter that does simple renames.
+LogicalResult simpleRename(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  StringRef opName = simpleRenameOpName;
+  StringRef match = simpleRenameMatch;
+  StringRef replace = simpleRenameReplace;
+  llvm::Regex regex(match);
+
+  rewriteBuffer.getParsed()->walk([&](Operation *op) {
+    if (op->getName().getStringRef() != opName)
+      return;
+
+    const OperationDefinition &opDef = rewriteBuffer.getOpDef(op);
+    SMRange range = getOpRange(opDef);
+    // This is a little bit overkill for simple.
+    std::string str = regex.sub(replace, rewriteBuffer.getSourceString(range));
+    rewriteBuffer.replaceRange(range, str);
+  });
+  return success();
+}
+
+static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
+                                                      "Perform a simple rename",
+                                                      simpleRename);
+
+// Rewriter that inserts range markers.
+LogicalResult markRanges(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  int i = 0;
+  for (auto it : rewriteBuffer.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+
+    rewriteBuffer.insertText(startOp, "《");
+    rewriteBuffer.insertText(endOp, "》");
+
+    auto nameRange = getOpNameRange(it);
+
+    if (isGeneric(it)) {
+      rewriteBuffer.insertText(nameRange.Start, "〖");
+      rewriteBuffer.insertText(nameRange.End, "〗");
+    } else {
+      rewriteBuffer.insertText(nameRange.Start, "〔");
+      rewriteBuffer.insertText(nameRange.End, "〕");
+    }
+    ++i;
+  }
+
+  // Highlight all comment lines.
+  // TODO: Could be replaced if this is kept in memory.
+  for (auto commentLine : rewriteBuffer.getSingleLineComments()) {
+    rewriteBuffer.insertText(commentLine.Start, "❰");
+    rewriteBuffer.insertText(commentLine.End, "❱");
+  }
+
+  return success();
+}
+
+static mlir::RewriterRegistration
+    rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
+
+int main(int argc, char **argv) {
+  static cl::opt<std::string> inputFilename(
+      cl::Positional, cl::desc("<input file>"), cl::init("-"));
+
+  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+                                             cl::value_desc("filename"),
+                                             cl::init("-"));
+
+  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
+      rewriter("", llvm::cl::desc("Rewriter to run"));
+
+  std::string helpHeader = "mlir-rewrite";
+
+  cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // If no rewriter has been selected, exit with error code. Could also just
+  // return but it's unlikely this was intentionally being used as `cp`.
+  if (!rewriter) {
+    llvm::errs() << "No rewriter selected!\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  // Set up rewrite buffer.
+  auto rewriterOr = RewriteBuffer::init(inputFilename, outputFilename);
+  if (!rewriterOr)
+    return mlir::asMainReturnCode(mlir::failure());
+
+  // Set up the output file.
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
+  if (succeeded(result)) {
+    rewriterOr->write(output->os());
+    output->keep();
+  }
+  return mlir::asMainReturnCode(result);
+}
+

@llvmbot
Member

llvmbot commented Jan 10, 2024

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

Changes

Initial commit of a tool to help in textual rewrites of .mlir files. This tool builds of of AsmParserState and is rather simple. Took some inspiration from when I used clang's AST rewrites where I'd often treat it as a "localizing" regex applicator in fallback cases, and started with that as functionality. There though, one does have access to the lower level info than here, but still a step up over sed over entire file.

This aims to be helpful (e.g., rewrite syntax including best effort inside comments) rather than bulletproof tool. It may even be better suited under utils than tools. And most of the rewrites would be rather short lived and might never make it upstream (while the helpers of those rewrites may for future rewrites).

The layering at the moment is not ideal as it is reusing the RewriteBuffer class from clang's rewrite engine. So only optionally enabling where clang is also enable. There doesn't seem to be anything clang specific there (the dep does pull in more dependencies than ideal, but leaving both refactorings).

Additionally started it as a single file to prototype more easily, planning to refactor later to include and libs for out of file usage.


Full diff: https://github.com/llvm/llvm-project/pull/77668.diff

9 Files Affected:

  • (modified) mlir/CMakeLists.txt (+5)
  • (added) mlir/docs/Tools/mlir-rewrite.md (+29)
  • (modified) mlir/test/CMakeLists.txt (+7)
  • (modified) mlir/test/lit.cfg.py (+4)
  • (modified) mlir/test/lit.site.cfg.py.in (+1)
  • (added) mlir/test/mlir-rewrite/simple.mlir (+12)
  • (modified) mlir/tools/CMakeLists.txt (+5)
  • (added) mlir/tools/mlir-rewrite/CMakeLists.txt (+37)
  • (added) mlir/tools/mlir-rewrite/mlir-rewrite.cpp (+395)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 2d9f78e03ba76b..64aad84e90a5ad 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -285,3 +285,8 @@ endif()
 if(MLIR_STANDALONE_BUILD)
   llvm_distribution_add_targets()
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  set(MLIR_ENABLE_REWRITE ON CACHE BOOL "mlir-rewrite enabled")
+endif()
diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md
new file mode 100644
index 00000000000000..178f92f72cbb6e
--- /dev/null
+++ b/mlir/docs/Tools/mlir-rewrite.md
@@ -0,0 +1,29 @@
+# mlir-rewrite
+
+Tool to simplify rewriting .mlir files. There are a couple of build in rewrites
+discussed below along with usage.
+
+Note: This is still in very early stage. Its so early its less a tool than a
+growing collection of useful functions: to use its best to do what's needed on
+a brance by just hacking it (dialects registered, rewrites etc) to say help
+ease a rename, upstream useful utility functions, point to ease others
+migrating, and then bin eventually. Once there are actually useful parts it
+should be refactored same as mlir-opt.
+
+[TOC]
+
+## simple-rename
+
+Rename per op given a substring to a target. The match and replace uses LLVM's
+regex sub for the match and replace while the op-name is matched via regular
+string comparison. E.g.,
+
+```
+mlir-rewrite input.mlir -o output.mlir --simple-rename \
+   --simple-rename-op-name="test.concat" --simple-rename-match="axis" \
+                                         --simple-rename-replace="bxis"
+```
+
+to replace `axis` substring in the text of the range corresponding to
+`test.concat` ops with `bxis`.
+
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 8ce030feeded92..397a2efcf5e9f2 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -197,6 +197,13 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   )
 endif()
 
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-rewrite
+  )
+endif()
+
 # This target can be used to just build the dependencies
 # for the check-mlir target without executing the tests.
 # This is useful for bots when splitting the build step
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 0a1ea1d16da452..35d6a3bd1f5636 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -144,6 +144,10 @@ def add_runtime(name):
         )
     )
 
+if config.enable_mlir_rewrite:
+    tools.extend(["mlir-rewrite"])
+    config.available_features.add('mlir-rewrite')
+
 # The following tools are optional
 tools.extend(
     [
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index c0fa1b8980e539..d35e3701198a56 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -23,6 +23,7 @@ config.mlir_obj_root = "@MLIR_BINARY_DIR@"
 config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
 config.mlir_cmake_dir = "@MLIR_CMAKE_DIR@"
 config.mlir_lib_dir = "@MLIR_LIB_DIR@"
+config.enable_mlir_rewrite = "@MLIR_ENABLE_REWRITE@"
 
 config.build_examples = @LLVM_BUILD_EXAMPLES@
 config.run_cuda_tests = @MLIR_ENABLE_CUDA_CONVERSIONS@
diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir
new file mode 100644
index 00000000000000..cf3a029b0653b0
--- /dev/null
+++ b/mlir/test/mlir-rewrite/simple.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME
+// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE
+// Note: running through mlir-opt to just strip out comments & avoid self matches.
+// REQUIRES: mlir-rewrite
+
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64}
+  // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>》
+  %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  return %5 : tensor<?x4x?xf32>
+}
+
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 9b474385fdae18..7d330e124a2ca2 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -15,3 +15,8 @@ add_subdirectory(tblgen-to-irdl)
 if(MLIR_ENABLE_EXECUTION_ENGINE)
   add_subdirectory(mlir-cpu-runner)
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  add_subdirectory(mlir-rewrite)
+endif()
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
new file mode 100644
index 00000000000000..29126432d2de5d
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -0,0 +1,37 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+set(LIBS
+  ${dialect_libs}
+  ${test_libs}
+
+  clangRewrite
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIROptLib
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  MLIRIR
+  )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-rewrite
+  mlir-rewrite.cpp
+
+  DEPENDS
+  ${LIBS}
+  SUPPORT_PLUGINS
+  )
+target_link_libraries(mlir-rewrite PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-rewrite)
+
+mlir_check_all_link_libraries(mlir-rewrite)
+export_executable_symbols_for_plugins(mlir-rewrite)
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
new file mode 100644
index 00000000000000..0648dc4309ab6c
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -0,0 +1,395 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+// TODO: Refactor the RewriteBuffer out to avoid the weird Clang dep.
+#include "clang/Rewrite/Core/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace llvm;
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source range associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  for (auto res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+class RewriteBuffer {
+public:
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions of the parsed operations.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the range of the operation in the source text with the
+  /// corresponding string in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the source string corresponding to the source range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the source string corresponding to operation definition.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    auto range = getOpRange(opDef);
+    return getSourceString(range);
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return lines that are purely comments.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    auto lineIterator = line_iterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+
+  /// Return the IR from parsed file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition for the given operation. The operation must have a
+  /// definition in the parser state; the lookup result is dereferenced as-is.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    return *asmState.getOpDef(op);
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The RewriteBuffer from clang-rewrite is doing most of the real work.
+  clang::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offset.
+  const char *start;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+                                                   StringRef outputFilename) {
+  std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  // Set up the MLIR context and error handling.
+  r->context.appendDialectRegistry(r->registry);
+
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = r->sourceMgr.getMainFileID();
+  const MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf);
+  r->start = curMB->getBufferStart();
+  r->rewriteBuffer.Initialize(curMB->getBuffer());
+
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
+                           &r->fallbackResourceMap);
+  // Always allow unregistered.
+  r->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig,
+                                &r->asmState)))
+    return nullptr;
+
+  return r;
+}
+
+/// Return the source range associated with the operation name.
+SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+
+/// Return whether the operation was printed using generic syntax in original
+/// buffer.
+bool isGeneric(const OperationDefinition &op) {
+  return op.loc.Start.getPointer()[0] == '"';
+}
+
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Rewriter function to invoke.
+using RewriterFunction = std::function<mlir::LogicalResult(
+    mlir::RewriteBuffer &rewriteBuffer, llvm::raw_ostream &os)>;
+
+/// Structure to group information about a rewriter (argument to invoke via
+/// mlir-rewrite, description, and rewriter function).
+class RewriterInfo {
+public:
+  /// RewriterInfo constructor should not be invoked directly, instead use
+  /// RewriterRegistration or registerRewriter.
+  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
+      : arg(arg), description(description), rewriter(std::move(rewriter)) {}
+
+  /// Invokes the rewriter and returns whether the rewriter failed.
+  LogicalResult invoke(mlir::RewriteBuffer &rewriteBuffer,
+                       raw_ostream &os) const {
+    assert(rewriter && "Cannot call rewriter with null rewriter");
+    return rewriter(rewriteBuffer, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-rewrite' to
+  /// invoke this rewriter.
+  StringRef getRewriterArgument() const { return arg; }
+
+  /// Returns a description for the rewriter.
+  StringRef getRewriterDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the rewriter via mlir-rewrite.
+  StringRef arg;
+
+  // Description of the rewriter.
+  StringRef description;
+
+  // Rewriter function.
+  RewriterFunction rewriter;
+};
+
+static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
+
+/// Adds command line option for each registered rewriter.
+struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
+  RewriterNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+};
+
+/// RewriterRegistration provides a global initializer that registers a rewriter
+/// function.
+struct RewriterRegistration {
+  RewriterRegistration(StringRef arg, StringRef description,
+                       const RewriterFunction &function);
+};
+
+RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
+                                           const RewriterFunction &function) {
+  rewriterRegistry->emplace_back(arg, description, function);
+}
+
+RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const RewriterInfo *>(opt) {
+  for (const auto &kv : *rewriterRegistry) {
+    addLiteralOption(kv.getRewriterArgument(), &kv,
+                     kv.getRewriterDescription());
+  }
+}
+
+void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
+                                         size_t globalWidth) const {
+  RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const RewriterNameParser::OptionInfo *vT1,
+                          const RewriterNameParser::OptionInfo *vT2) {
+                         return vT1->Name.compare(vT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
+}
+
+} // namespace mlir
+
+// TODO: Make these injectable too in non-global way.
+static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
+static llvm::cl::opt<std::string> simpleRenameOpName{
+    "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameMatch{
+    "simple-rename-match", llvm::cl::desc("Match string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameReplace{
+    "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+
+// Rewriter that does simple renames.
+LogicalResult simpleRename(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  StringRef opName = simpleRenameOpName;
+  StringRef match = simpleRenameMatch;
+  StringRef replace = simpleRenameReplace;
+  llvm::Regex regex(match);
+
+  rewriteBuffer.getParsed()->walk([&](Operation *op) {
+    if (op->getName().getStringRef() != opName)
+      return;
+
+    const OperationDefinition &opDef = rewriteBuffer.getOpDef(op);
+    SMRange range = getOpRange(opDef);
+    // This is a little bit overkill for simple.
+    std::string str = regex.sub(replace, rewriteBuffer.getSourceString(range));
+    rewriteBuffer.replaceRange(range, str);
+  });
+  return success();
+}
+
+static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
+                                                      "Perform a simple rename",
+                                                      simpleRename);
+
+// Rewriter that inserts range markers.
+LogicalResult markRanges(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  int i = 0;
+  for (auto it : rewriteBuffer.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+
+    rewriteBuffer.insertText(startOp, "《");
+    rewriteBuffer.insertText(endOp, "》");
+
+    auto nameRange = getOpNameRange(it);
+
+    if (isGeneric(it)) {
+      rewriteBuffer.insertText(nameRange.Start, "〖");
+      rewriteBuffer.insertText(nameRange.End, "〗");
+    } else {
+      rewriteBuffer.insertText(nameRange.Start, "〔");
+      rewriteBuffer.insertText(nameRange.End, "〕");
+    }
+    ++i;
+  }
+
+  // Highlight all comment lines.
+  // TODO: Could be replaced if this is kept in memory.
+  for (auto commentLine : rewriteBuffer.getSingleLineComments()) {
+    rewriteBuffer.insertText(commentLine.Start, "❰");
+    rewriteBuffer.insertText(commentLine.End, "❱");
+  }
+
+  return success();
+}
+
+static mlir::RewriterRegistration
+    rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
+
+int main(int argc, char **argv) {
+  static cl::opt<std::string> inputFilename(
+      cl::Positional, cl::desc("<input file>"), cl::init("-"));
+
+  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+                                             cl::value_desc("filename"),
+                                             cl::init("-"));
+
+  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
+      rewriter("", llvm::cl::desc("Rewriter to run"));
+
+  std::string helpHeader = "mlir-rewrite";
+
+  cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // If no rewriter has been selected, exit with error code. Could also just
+  // return but it's unlikely this was intentionally being used as `cp`.
+  if (!rewriter) {
+    llvm::errs() << "No rewriter selected!\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  // Set up rewrite buffer.
+  auto rewriterOr = RewriteBuffer::init(inputFilename, outputFilename);
+  if (!rewriterOr)
+    return mlir::asMainReturnCode(mlir::failure());
+
+  // Set up the output file.
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
+  if (succeeded(result)) {
+    rewriterOr->write(output->os());
+    output->keep();
+  }
+  return mlir::asMainReturnCode(result);
+}
+
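The `RewriteBuffer` wrapper in the diff above expresses every edit as an offset into the original input (`pos.getPointer() - start`), so earlier insertions never shift the positions used by later ones. The block below is a minimal, dependency-free sketch of that idea. `OffsetRewriter` and `markExample` are hypothetical names used only for illustration; this is not the clang class the patch reuses, and as a simplification it silently drops edits that start inside an already replaced span.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a rewrite buffer: edits are recorded against
// offsets in the ORIGINAL text and only applied when str() is called.
class OffsetRewriter {
public:
  explicit OffsetRewriter(std::string text) : original(std::move(text)) {}

  // Insert `text` before the original character at `offset`.
  void insertText(std::size_t offset, std::string text) {
    edits.push_back({offset, 0, std::move(text)});
  }

  // Replace `length` original characters starting at `offset`.
  void replaceText(std::size_t offset, std::size_t length, std::string text) {
    edits.push_back({offset, length, std::move(text)});
  }

  // Build the rewritten text; the original buffer is left untouched.
  std::string str() const {
    std::vector<Edit> sorted = edits;
    // Stable sort keeps same-offset edits in the order they were recorded.
    std::stable_sort(sorted.begin(), sorted.end(),
                     [](const Edit &a, const Edit &b) {
                       return a.offset < b.offset;
                     });
    std::string out;
    std::size_t cursor = 0;
    for (const Edit &e : sorted) {
      // Simplification: skip out-of-range edits and edits that start inside
      // a span that an earlier edit already replaced.
      if (e.offset < cursor || e.offset + e.length > original.size())
        continue;
      out.append(original, cursor, e.offset - cursor);
      out += e.text;
      cursor = e.offset + e.length;
    }
    out.append(original, cursor, std::string::npos);
    return out;
  }

private:
  struct Edit {
    std::size_t offset;
    std::size_t length;
    std::string text;
  };
  std::string original;
  std::vector<Edit> edits;
};

// Mark a whole op and its name, loosely following what mark-ranges does,
// but with ASCII markers instead of the brackets used in the patch.
std::string markExample() {
  const std::string text = "%5 = \"test.concat\"(%a)";
  const std::size_t nameStart = text.find('"');
  const std::size_t nameEnd = text.find('"', nameStart + 1) + 1;
  OffsetRewriter rewriter(text);
  rewriter.insertText(0, "<<");
  rewriter.insertText(text.size(), ">>");
  rewriter.insertText(nameStart, "[");
  rewriter.insertText(nameEnd, "]");
  return rewriter.str();
}
```

Because all four insertions use offsets computed from the unmodified text, the order in which they are recorded does not affect where they land.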

github-actions bot commented Jan 10, 2024

✅ With the latest revision this PR passed the Python code formatter.

github-actions bot commented Jan 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@Mogball Mogball left a comment

Awesome! Does this work with custom syntax though?

Collaborator

@joker-eph joker-eph left a comment

This is a nice POC, but the clang dependency makes it a hard blocker to me to land this in tree.

@jpienaar jpienaar force-pushed the mlirrewrite branch 2 times, most recently from 702ac0f to 72fbd9b Compare August 19, 2024 03:26
@jpienaar
Member Author

Awesome! Does this work with custom syntax though?

It does in my testing. Now it's very basic source ranges ...
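For context on the custom-syntax question: the patch tells the two forms apart by checking whether the recorded op-name range begins with a double quote (`isGeneric`), and `mark-ranges` then picks a different bracket pair for each. A standalone sketch of that check, assuming C++17 for `std::string_view`; `isGenericName` and `markName` are hypothetical names, and the `G[`/`C[` markers are ASCII placeholders rather than the brackets used in the patch:

```cpp
#include <string>
#include <string_view>

// Mirrors the idea of the check in the patch: in the generic form the op
// name is printed as a quoted string, so its first character is '"'.
bool isGenericName(std::string_view nameText) {
  return !nameText.empty() && nameText.front() == '"';
}

// Hypothetical helper: wrap the name in a marker that records which form
// was seen ("G" for generic, "C" for custom syntax).
std::string markName(std::string_view nameText) {
  std::string out = isGenericName(nameText) ? "G[" : "C[";
  out.append(nameText.data(), nameText.size());
  out += "]";
  return out;
}
```

With this, `markName("\"test.concat\"")` is tagged as generic while `markName("return")` is tagged as custom.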

This is a nice POC, but the clang dependency makes it a hard blocker to me to land this in tree.

SG, removed clang dep and optional building.

@jpienaar jpienaar requested a review from joker-eph October 4, 2024 22:09
std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();

// Register all the dialects needed.
registerAllDialects(r->registry);
Collaborator

There will be some non-trivial refactoring needed to make it a properly usable utility (like mlir-opt is separated from the MlirOptMain library, for example).
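One possible shape for such a split, sketched on plain strings so it stays self-contained: a library-side registry plus an entry point that a thin `main` (or a downstream tool with its own rewriters) could call. Every name here (`TextRewriter`, `RewriterRegistry`, `rewriteMain`) is hypothetical and does not exist in MLIR; the real refactoring would work on the parsed buffer rather than on `std::string`.

```cpp
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical library-side types; not part of MLIR.
using TextRewriter = std::function<std::string(const std::string &)>;

// Registry owned by the caller instead of a global static, so several
// tools can each register their own set of rewriters.
class RewriterRegistry {
public:
  void add(const std::string &name, TextRewriter fn) {
    rewriters[name] = std::move(fn);
  }

  // Return the rewriter registered under `name`, or nullptr if none is.
  const TextRewriter *lookup(const std::string &name) const {
    auto it = rewriters.find(name);
    return it == rewriters.end() ? nullptr : &it->second;
  }

private:
  std::map<std::string, TextRewriter> rewriters;
};

// Library entry point: run the named rewriter over `input`. Returns
// std::nullopt when no rewriter with that name is registered.
std::optional<std::string> rewriteMain(const RewriterRegistry &registry,
                                       const std::string &name,
                                       const std::string &input) {
  const TextRewriter *fn = registry.lookup(name);
  if (!fn)
    return std::nullopt;
  return (*fn)(input);
}
```

A driver would then only parse command-line options, fill a registry, and forward to `rewriteMain`.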

@llvmbot llvmbot added the bazel "Peripheral" support tier build system: utils/bazel label Oct 12, 2024
Initial commit of a tool to help with textual rewrites of .mlir files.
This tool builds on top of AsmParserState and is rather simple. Took some
inspiration from when I used clang's AST rewrites where I'd often treat
it as a "localizing" regex applicator in fallback cases, and started
with that as functionality. There, though, one has access to lower-level
info than here, but this is still a step up over running sed over the
entire file.

This aims to be a helpful tool (e.g., rewriting syntax, including best
effort inside comments) rather than a bulletproof one. It may even be
better suited under utils than tools. Most of the rewrites would be
rather short-lived and might never make it upstream (while the helpers
of those rewrites may, for use in future rewrites).

Started it as a single file to prototype more easily, planning to
refactor later to include and libs for out-of-file usage.
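The "localizing" regex idea in the message above can be illustrated without MLIR: apply the substitution only inside a known source range (for example one op's text) and copy everything else through unchanged, so a comment that happens to match is left alone. This sketch uses `std::regex` instead of `llvm::Regex`; `replaceInRange` is a hypothetical helper, and replacing only the first match per range (`format_first_only`) is an assumption about the intended behaviour, not something the commit message states.

```cpp
#include <cstddef>
#include <regex>
#include <string>

// Hypothetical helper: substitute `pattern` only inside text[begin, end)
// and leave the rest of the buffer byte-for-byte identical.
std::string replaceInRange(const std::string &text, std::size_t begin,
                           std::size_t end, const std::regex &pattern,
                           const std::string &replacement) {
  // Ignore malformed ranges rather than guessing what was meant.
  if (begin > end || end > text.size())
    return text;
  std::string scoped = std::regex_replace(
      text.substr(begin, end - begin), pattern, replacement,
      std::regex_constants::format_first_only);
  return text.substr(0, begin) + scoped + text.substr(end);
}
```

For example, with a buffer whose first line is a comment mentioning `axis` and whose second line is an op carrying an `axis` attribute, passing the op's range rewrites only the attribute, whereas `sed` over the whole file would touch both.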
@jpienaar jpienaar merged commit 4c25a53 into llvm:main Oct 12, 2024
8 checks passed
@harrisonGPU
Copy link
Contributor

@jpienaar, hello, I encountered a build issue because this change uses Unicode characters. If you have time, please take a look. :)

#112300

DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
harrisonGPU added a commit that referenced this pull request Oct 18, 2024
This issue is from #77668. I
encountered a build issue because it used Unicode. When I built MLIR on
Windows with Visual Studio 2022, I faced a build failure.

---------

Co-authored-by: Harrison Hao <[email protected]>
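For readers hitting the same failure: the `mark-ranges` rewriter embeds non-ASCII bracket characters directly in string literals, and a compiler that does not read the source file as UTF-8 can reject or mis-decode them. The fix in the referenced commit is not shown on this page; one common workaround, sketched here as an assumption, is to spell the UTF-8 bytes with escapes so the source stays pure ASCII. The names `kOpRangeOpen`, `kOpRangeClose`, and `markRange` are hypothetical.

```cpp
#include <string>

// UTF-8 encodings written as escaped bytes, so this file contains only
// ASCII characters regardless of how the compiler decodes the source.
// Note: a hex escape consumes every following hex digit, so keep each
// escaped sequence in its own literal when ordinary text follows it.
constexpr char kOpRangeOpen[] = "\xE3\x80\x8A";  // U+300A
constexpr char kOpRangeClose[] = "\xE3\x80\x8B"; // U+300B

// Hypothetical helper: wrap `text` in the two range markers.
std::string markRange(const std::string &text) {
  return std::string(kOpRangeOpen) + text + kOpRangeClose;
}
```

Each marker is three bytes long, so wrapping a one-character string yields seven bytes of output.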