[mlir] Retain original identifier names for debugging #79704

Status: Open. Wants to merge 11 commits into base: main.

Conversation

@Wheest (Contributor) commented Jan 27, 2024

This PR implements retention of MLIR identifier names (e.g., %my_val, ^bb_foo) for debugging and development purposes.

A wider discussion of this feature is in this discourse thread.

A motivating example: right now, IR generation drops all meaningful identifier names, even though these names could be useful for developers trying to understand their passes, or for other tooling (e.g., MLIR code formatters).

func.func @add_one(%my_input: f64) -> f64 {
    %my_constant = arith.constant 1.00000e+00 : f64
    %my_output = arith.addf %my_input, %my_constant : f64
    return %my_output : f64
}

⬇️ becomes

func.func @add_one(%arg0: f64) -> f64 {
  %cst = arith.constant 1.000000e+00 : f64
  %0 = arith.addf %arg0, %cst : f64
  return %0 : f64
}

The solution this PR implements is to store this metadata inside the attribute dictionary of operations, under a special namespace (e.g., mlir.resultNames). This means that this optional feature (turned on by the mlir-opt flag --retain-identifier-names) does not incur any additional overhead, except in text parsing and printing (AsmParser/Parser.cpp and IR/AsmPrinter.cpp).

Alternative solutions, such as adding a string field to the Value class, reparsing location information, and adapting OpAsmInterface, are discussed in the relevant discourse thread.

I've implemented some initial test cases in mlir/test/IR/print-retain-identifiers.mlir.

This covers things such as:

  • retaining the result names of operations
  • retaining the names of basic blocks (and their arguments)
  • handling result groups

A case that I know does not work is when a func.func argument is unused. This is because we recover the SSA names of these arguments from the operations that use them. You can see in the first test case that the 2nd argument is never used, so it will always fall back to the default anonymous name (%arg1).

I do not have test cases for how the system handles code transformations, and am open to suggestions for additional tests to include. Also note that this is my most substantial contribution to the codebase to date, so I may require some shepherding with regard to coding style or use of core LLVM library constructs.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 27, 2024
@llvmbot (Member) commented Jan 27, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Perry Gibson (Wheest)


Full diff: https://github.com/llvm/llvm-project/pull/79704.diff

6 Files Affected:

  • (modified) mlir/include/mlir/IR/AsmState.h (+9-3)
  • (modified) mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h (+10)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+70)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+72-6)
  • (modified) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp (+8-2)
  • (added) mlir/test/IR/print-retain-identifiers.mlir (+92)
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..9c4eadb04cdf2f 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
   /// Return the underlying data as an array of the given type. This is an
   /// inherrently unsafe operation, and should only be used when the data is
   /// known to be of the correct type.
-  template <typename T>
-  ArrayRef<T> getDataAs() const {
+  template <typename T> ArrayRef<T> getDataAs() const {
     return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
   }
 
@@ -464,8 +463,10 @@ class ParserConfig {
   /// `fallbackResourceMap` is an optional fallback handler that can be used to
   /// parse external resources not explicitly handled by another parser.
   ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
-               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+               bool retainIdentifierNames = false)
       : context(context), verifyAfterParse(verifyAfterParse),
+        retainIdentifierNames(retainIdentifierNames),
         fallbackResourceMap(fallbackResourceMap) {
     assert(context && "expected valid MLIR context");
   }
@@ -476,6 +477,10 @@ class ParserConfig {
   /// Returns if the parser should verify the IR after parsing.
   bool shouldVerifyAfterParse() const { return verifyAfterParse; }
 
+  /// Returns if the parser should retain identifier names collected using
+  /// parsing.
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
   /// Returns the parsing configurations associated to the bytecode read.
   BytecodeReaderConfig &getBytecodeReaderConfig() const {
     return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
 private:
   MLIRContext *context;
   bool verifyAfterParse;
+  bool retainIdentifierNames;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
   FallbackAsmResourceMap *fallbackResourceMap;
   BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21..a85dca186a4f3c 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
   /// Reproducer file generation (no crash required).
   StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
 
+  /// Set whether to retain the original names of identifiers when printing.
+  MlirOptMainConfig &retainIdentifierNames(bool retain) {
+    retainIdentifierNamesFlag = retain;
+    return *this;
+  }
+  bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
   /// the corresponding line. This is meant for implementing diagnostic tests.
   bool verifyDiagnosticsFlag = false;
 
+  /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+  bool retainIdentifierNamesFlag = false;
+
   /// Run the verifier after each transformation pass.
   bool verifyPassesFlag = true;
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..247e99e61c2c01 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
   /// an object of type 'OperationName'. Otherwise, failure is returned.
   FailureOr<OperationName> parseCustomOperationName();
 
+  /// Store the SSA names for the current operation as attrs for debug purposes.
+  void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
+  DenseMap<Value, StringRef> argNames;
+
   //===--------------------------------------------------------------------===//
   // Region Parsing
   //===--------------------------------------------------------------------===//
@@ -1268,6 +1272,58 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
                                       /*allowEmptyList=*/false);
 }
 
+/// Store the SSA names for the current operation as attrs for debug purposes.
+void OperationParser::storeSSANames(Operation *&op,
+                                    ArrayRef<ResultRecord> resultIDs) {
+
+  // Store the name(s) of the result(s) of this operation.
+  if (op->getNumResults() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> resultNames;
+    for (const ResultRecord &resIt : resultIDs) {
+      resultNames.push_back(std::get<0>(resIt).drop_front(1));
+      // Insert empty string for sub-results/result groups
+      for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+        resultNames.push_back(llvm::StringRef());
+    }
+    op->setDiscardableAttr("mlir.resultNames",
+                           builder.getStrArrayAttr(resultNames));
+  }
+
+  // Store the name information of the arguments of this operation.
+  if (op->getNumOperands() > 0) {
+    llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+    for (auto &operand : op->getOpOperands()) {
+      auto it = argNames.find(operand.get());
+      if (it != argNames.end())
+        opArgNames.push_back(it->second.drop_front(1));
+    }
+    op->setDiscardableAttr("mlir.opArgNames",
+                           builder.getStrArrayAttr(opArgNames));
+  }
+
+  // Store the name information of the block that contains this operation.
+  Block *blockPtr = op->getBlock();
+  for (const auto &map : blocksByName) {
+    for (const auto &entry : map) {
+      if (entry.second.block == blockPtr) {
+        op->setDiscardableAttr("mlir.blockName",
+                               StringAttr::get(getContext(), entry.first));
+
+        // Store block arguments, if present
+        llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
+
+        for (BlockArgument arg : blockPtr->getArguments()) {
+          auto it = argNames.find(arg);
+          if (it != argNames.end())
+            blockArgNames.push_back(it->second.drop_front(1));
+        }
+        op->setAttr("mlir.blockArgNames",
+                    builder.getStrArrayAttr(blockArgNames));
+      }
+    }
+  }
+}
+
 namespace {
 // RAII-style guard for cleaning up the regions in the operation state before
 // deleting them.  Within the parser, regions may get deleted if parsing failed,
@@ -1672,6 +1728,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
                              SmallVectorImpl<Value> &result) override {
     if (auto value = parser.resolveSSAUse(operand, type)) {
       result.push_back(value);
+
+      // Optionally store argument name for debug purposes
+      if (parser.getState().config.shouldRetainIdentifierNames())
+        parser.argNames.insert({value, operand.name});
+
       return success();
     }
     return failure();
@@ -2031,6 +2092,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Otherwise, create the operation and try to parse a location for it.
   Operation *op = opBuilder.create(opState);
+
+  // If enabled, store the SSA name(s) for the operation
+  if (state.config.shouldRetainIdentifierNames())
+    storeSSANames(op, resultIDs);
+
   if (parseTrailingLocationSpecifier(op))
     return nullptr;
 
@@ -2355,6 +2421,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
           } else {
             auto loc = getEncodedSourceLocation(useInfo.location);
             arg = owner->addArgument(type, loc);
+
+            // Optionally store argument name for debug purposes
+            if (state.config.shouldRetainIdentifierNames())
+              argNames.insert({arg, useInfo.name});
           }
 
           // If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a..84603bb6ebfba3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
 
 /// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
 ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
-    return parseCommaSeparatedList(
-        [&]() { return parseType(result.emplace_back()); });
-  }
-
-
+  return parseCommaSeparatedList(
+      [&]() { return parseType(result.emplace_back()); });
+}
 
 //===----------------------------------------------------------------------===//
 // DialectAsmPrinter
@@ -1299,6 +1298,11 @@ class SSANameState {
   /// conflicts, it is automatically renamed.
   StringRef uniqueValueName(StringRef name);
 
+  /// Set the original identifier names if available. Used in debugging with
+  /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  void setRetainedIdentifierNames(Operation &op,
+                                  SmallVector<int, 2> &resultGroups);
+
   /// This is the value ID for each SSA value. If this returns NameSentinel,
   /// then the valueID has an entry in valueNames.
   DenseMap<Value, unsigned> valueIDs;
@@ -1568,6 +1572,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
     }
   }
 
+  // Set the original identifier names if available. Used in debugging with
+  // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+  setRetainedIdentifierNames(op, resultGroups);
+
   unsigned numResults = op.getNumResults();
   if (numResults == 0) {
     // If value users should be printed, operations with no result need an id.
@@ -1590,6 +1598,64 @@ void SSANameState::numberValuesInOp(Operation &op) {
   }
 }
 
+void SSANameState::setRetainedIdentifierNames(
+    Operation &op, SmallVector<int, 2> &resultGroups) {
+  // Get the original names for the results if available
+  if (ArrayAttr resultNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+    auto resultNames = resultNamesAttr.getValue();
+    auto results = op.getResults();
+    // Conservative in the case that the #results has changed
+    for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+      auto resultName = resultNames[i].cast<StringAttr>().strref();
+      if (!resultName.empty()) {
+        if (!usedNames.count(resultName))
+          setValueName(results[i], resultName);
+        // If a result has a name, it is the start of a result group.
+        if (i > 0)
+          resultGroups.push_back(i);
+      }
+    }
+    op.removeDiscardableAttr("mlir.resultNames");
+  }
+
+  // Get the original name for the op args if available
+  if (ArrayAttr opArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+    auto opArgNames = opArgNamesAttr.getValue();
+    auto opArgs = op.getOperands();
+    // Conservative in the case that the #operands has changed
+    for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+      auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(opArgName))
+        setValueName(opArgs[i], opArgName);
+    }
+    op.removeDiscardableAttr("mlir.opArgNames");
+  }
+
+  // Get the original name for the block if available
+  if (StringAttr blockNameAttr =
+          op.getAttrOfType<StringAttr>("mlir.blockName")) {
+    blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+    op.removeDiscardableAttr("mlir.blockName");
+  }
+
+  // Get the original name for the block args if available
+  if (ArrayAttr blockArgNamesAttr =
+          op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+    auto blockArgNames = blockArgNamesAttr.getValue();
+    auto blockArgs = op.getBlock()->getArguments();
+    // Conservative in the case that the #args has changed
+    for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+      auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+      if (!usedNames.count(blockArgName))
+        setValueName(blockArgs[i], blockArgName);
+    }
+    op.removeDiscardableAttr("mlir.blockArgNames");
+  }
+  return;
+}
+
 void SSANameState::getResultIDAndNumber(
     OpResult result, Value &lookupValue,
     std::optional<int> &lookupResultNo) const {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d78..c4482435861590 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
         cl::location(verifyRoundtripFlag), cl::init(false));
 
+    static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+        "retain-identifier-names",
+        cl::desc("Retain the original names of identifiers when printing"),
+        cl::location(retainIdentifierNamesFlag), cl::init(false));
+
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
 
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
   // untouched.
   PassReproducerOptions reproOptions;
   FallbackAsmResourceMap fallbackResourceMap;
-  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
-                           &fallbackResourceMap);
+  ParserConfig parseConfig(
+      context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+      /*retainIdentifierNames=*/config.shouldRetainIdentifierNames());
   if (config.shouldRunReproducer())
     reproOptions.attachResourceParser(parseConfig);
 
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 00000000000000..b3e4f075b3936a
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+  // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+  %my_constant = arith.constant 1.000000e+00 : f64
+  // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+  %my_output = arith.addf %my_input, %my_constant : f64
+  // CHECK: return %my_output : f64
+  return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+  // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+  cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta:  // pred: ^bb_alpha
+^bb_beta:
+  // CHECK: cf.br ^bb_delta(%a : i64)
+  cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma:  // pred: ^bb_alpha
+^bb_gamma:
+  // CHECK: %b = arith.addi %a, %a : i64
+  %b = arith.addi %a, %a : i64
+  // CHECK: cf.br ^bb_delta(%b : i64)
+  cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64):  // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+  // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+  cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64):  // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+  // CHECK: %f = arith.addi %d, %e : i64
+  %f = arith.addi %d, %e : i64
+  return %f : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+  %min, %max = scf.if %gt -> (f64, f64) {
+    scf.yield %b, %a : f64, f64
+  } else {
+    scf.yield %a, %b : f64, f64
+  }
+  // CHECK: return %min, %max : f64, f64
+  return %min, %max : f64, f64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+  // Find the max between %a and %b,
+  // with %c and %d being other values that are returned.
+  %gt = arith.cmpf "ogt", %a, %b : f64
+  // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+  %max, %others:2, %alt  = scf.if %gt -> (f64, f64, f64, f64) {
+    scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+  } else {
+    scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+  }
+  // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+  return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----

github-actions bot commented Jan 27, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@joker-eph (Collaborator) left a comment

This is a fairly straightforward solution, however the requirements aren't super clear to me to understand if this is really the right solution (for example the interactions with the AsmOpInterface can be subtle).

I also have concerns about using the dictionary of attributes here: the main purpose of the textual IR is to be faithful to the in-memory representation, and this is breaking this contract.

@Wheest (Contributor, Author) commented Jan 28, 2024

Thanks for the review! I've tried to enumerate some requirements gathered from the discourse discussion and my own opinions, see below.

First though, regarding:

The main purpose of the textual IR is to be faithful to the in-memory representation, and this is breaking this contract.

This is a valid concern. I'll try and advocate for this being okay and/or not unprecedented:

  • The in-memory representation also includes the location information of the operation, but the textual IR also does not print this by default
  • Right now I am setting the names as "discardable attributes", which may be consumed by passes, so discarding the attributes as a "final pass" at print time is not entirely new behaviour.
  • Additionally, since the default anonymous names are generated at print time, using a different naming scheme doesn't misrepresent the IR
  • Although a given dialect or transformation could in principle do anything with the attribute dict, my intuition is that a well-formed dialect should either ignore or drop attributes it is not familiar with. Chris Lattner said here "While it may be imperfect MLIR officially sanctions installing “other dialect attributes” on ops."

I could add an additional flag such as --no-attr-dict-cleanup, so that users could optionally print the unmodified attr dicts (i.e., make my op.removeDiscardableAttr("mlir.resultNames"); lines optional). This could also be helpful for diagnosing unexpected namings generated by this pass.

This would generate IR which looks like:

IR with no attr dict cleanup

module {
  func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
    %my_constant = arith.constant {mlir.resultNames = ["my_constant"]} 1.000000e+00 : f64
    %my_output = arith.addf %my_input, %my_constant {mlir.opArgNames = ["my_input", "my_constant"], mlir.resultNames = ["my_output"]} : f64
    return {mlir.opArgNames = ["my_output"]} %my_output : f64
  }
}

Secondly, regarding:

the interactions with the AsmOpInterface can be subtle

Yes, although the function that actually sets the names is a lambda which is passed by the caller to the operation (e.g., the lambda in AsmPrinter we care about). This means that although we can't control what the operation does in its implementation, we can control what we do with the string it gives us.

Your comment makes me realise that I should add test cases where our chosen names might clash with the names chosen by an AsmOpInterface. I'll add that in a future commit.

Requirements

Legend
✅: I think we cover this
🟠: there could be an issue or improvement made. Not a blocker to merging
⛔: not currently covered, blocker to merging

Functionality

  1. 🟠 Identifier names: SSA names, basic block names, etc., should be retained when IR is printed, rather than values being given anonymous names
  2. ✅ Optional feature: this feature should not be default behaviour (it is enabled with a ParserConfig flag)
  3. 🟠 Stable to errors: passes mean that operations and blocks can be mutated/added/deleted and the name list will drift. Misnaming may be acceptable, but crashes are not

Performance

  1. ✅ No additional overhead introduced to the critical path: since we use an attribute dict, no overhead is introduced in core IR. Textual parsing and printing incur a small cost, but they are not performance critical (advocated here)

Maintainability

  1. ✅ Low-code footprint: this solution has a small code footprint, thus the logic is quite easy to read, update, or replace
  2. 🟠 Tested: the solution should be incorporated into the main test harness (e.g., the FileCheck MLIR tests)

🟠: Issues to discuss or address

Functionality

  • 1] Identifier names - we currently don't cover function arguments that are not used (since their names are stored in operations which use them). This is only marginally useful - if an arg isn't used, who cares? Still, if I can support it I will
  • 3] Stable to errors - I try to cover this with my compound condition for loops in SSANameState::setRetainedIdentifierNames (i.e., we take the minimum of the arguments and names we have). However perhaps this could be better tested by applying some passes which change the number of arguments. I'm unsure of a good pass to apply.

Maintainability

  • 2] Expand test coverage: Development of more test cases, ideally to test system behaviour under various transformation passes. Open to suggestions for specific scenarios that should be covered.

@Wheest (Contributor, Author) commented Jan 29, 2024

I've managed to fix the issue around unused func.func arguments by handling them explicitly.

Reading the Func dialect docs, they say that;

While the MLIR textual form provides a nice inline syntax for function arguments, they are internally represented as “block arguments” to the first block in the region.

Therefore, I handle this case in a generic way with mlir.regionArgNames.

Also, following your suggestion about how AsmOpInterface behaves, I added a test case where I set names that explicitly clash with and contradict the default anonymous names (e.g., %arg0, %cst_1, %1, etc.).

@Wheest Wheest requested a review from joker-eph January 29, 2024 13:13
@joker-eph (Collaborator) left a comment

I would want to see at least a design for another more principled approach before possibly considering this one.

At the moment I'm more on the side of "this feature isn't worth the drawback of polluting the IR".

@Wheest (Contributor, Author) commented Jan 30, 2024

I would want to see at least a design for another more principled approach before possibly considering this one.

At the moment I'm more on the side of "this feature isn't worth the drawback of polluting the IR".

Excluding adding a new member variable to the Value class (which adds an unacceptable constant overhead), I think there are two approaches that are worth considering.

1. Location based approach

Would not require changes to the IR, however it would require re-parsing, which could add a lot more complexity (e.g., how do we reassociate a newly parsed Value and its name with the original Value?)

There are also clients of location info, for whom this feature could cause problems.

2. Use of external data structure

We could create an external data structure, e.g., a DenseMap<Value, StringRef>. I did this in an initial implementation, but made the owner MLIRContext, which is not desirable since that's not really what MLIRContext is for.

However, typically (e.g., in mlir-opt) the parser and printer are in the same scope, so we could make the map user-controlled, and (optionally) pass it to the Parser (which populates the map), and then to the Printer (which reads the map and updates the names accordingly).

We have something similar to this already, the FallbackAsmResourceMap, which is passed to both the parser and printer. I'm not sure if this would be the best place to put this map, or just to handle it separately. It seems FallbackAsmResourceMap is used in BytecodeWriter (to leverage its printers). Even adding an empty map to the Bytecode system is not attractive to me since I believe it is more performance critical.

A direct Value mapping approach also has the advantage compared to the attribute dict approach of reducing ambiguity in the case of IR transformation. If we introduce new arguments or operations, unless the Value is the same, we don't touch the name.


Personally, I think option 2 is a more straightforward and less intrusive solution. If we made the map a user-managed object (e.g., owned by mlir-opt's performActions function), it keeps our requirement of only adding overhead to the AsmParser and AsmPrinter.

Regarding the attribute dictionary approach, the only situation when the IR would be polluted is when this flag is enabled, which limits its impact.

@jpienaar (Member) commented Feb 2, 2024

One thing I was wondering about here (and sorry, I haven't had time to follow the other discussion - it also sounded like there was some confusion about what I said there, as I saw 2 uses and meant different comments for each). For 2, have we considered enabling creating an AsmState from an AsmParserState? That would handle this, right (it's in the 2nd approach model)? It would not handle tracking mutations - but then again, mutation tracking would be duplicated via other means. It may require allowing mutations of AsmState... so potentially a separate mapping may be better.

@Wheest
Contributor Author

Wheest commented Feb 20, 2024

For 2 have we considered enabling creating an AsmState from an AsmParserState?

Thanks @jpienaar, I hadn't considered that, but it looks like AsmState could be a good way to manage the name data.
Looking at the description of AsmState:

The following classes enable support for parsing and printing resources
within MLIR assembly formats. Resources are a mechanism by which dialects,
and external clients, may attach additional information when parsing or
printing IR without that information being encoded in the IR itself.
Resources are not uniqued within the MLIR context, are not attached directly
to any operation, and are solely intended to live and be processed outside
of the immediate IR.

That looks like a good way to store this data, rather than letting mlir-opt handle it as I suggested.

It may require allowing mutations of AsmState ... so potentially a separate mapping may be better.

AsmResourceBlob has a dataIsMutable member variable, which could mean this is already possible. But I'm not that familiar with how AsmState behaves, so perhaps it isn't.

Regarding making an AsmState from an AsmParserState, I think it would be simpler to just create an AsmState from a DenseMap<Value, StringRef> (though, given how AsmResourceBlob works, I think the map would need to be serialized).
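Since an AsmResourceBlob holds raw bytes rather than structured data, the map would have to be flattened into a blob and parsed back out. A minimal sketch of that round trip, using value IDs as stand-in keys (serializeNames/deserializeNames and the line-based format are illustrative assumptions, not MLIR API):

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <sstream>
#include <string>

// Serialize a value-id -> name map into a flat "id=name\n" blob.
std::string serializeNames(const std::map<uint64_t, std::string> &names) {
  std::ostringstream os;
  for (const auto &[id, name] : names)
    os << id << '=' << name << '\n';
  return os.str();
}

// Parse the blob back into a map, splitting each line on the first '='.
std::map<uint64_t, std::string> deserializeNames(const std::string &blob) {
  std::map<uint64_t, std::string> names;
  std::istringstream is(blob);
  std::string line;
  while (std::getline(is, line)) {
    auto eq = line.find('=');
    names[std::stoull(line.substr(0, eq))] = line.substr(eq + 1);
  }
  return names;
}
```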

We could have something akin to a DenseMap<Value, AsmParserState>, but I'm not sure what benefit that would bring. We could recover the StringRef for the identifiers by looking at the location information in SMDefinition. Perhaps the extra data would make it more extensible? But this extra processing of AsmParserState at print time would add complexity that this feature doesn't really need.

Thoughts on this design? I'll try making another version that has the external data structure, first managed by mlir-opt, and then stored as an AsmState.

@makslevental
Contributor

Can we iterate on this? Preserving legible SSA names through parsing, and even through passes, is immensely valuable for debugging everything. Note, I don't understand the relationship to formatters, so what follows doesn't address that use case.

I propose we store the identifier in the Location info as a NameLoc. As far as I can tell, this avoids the issue of polluting the IR.
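The key property of the NameLoc approach can be sketched without MLIR itself. In this stand-in sketch (Loc, withName, and cloneLocForNewOp are hypothetical names, not MLIR's actual NameLoc API), a name attached to a location automatically follows the location into any op that is handed the original op's location:

```cpp
#include <cassert>
#include <string>

// Stand-in for mlir::Location wrapped in an mlir::NameLoc.
struct Loc {
  std::string file; // underlying source location
  std::string name; // identifier carried by the wrapping "NameLoc", if any
};

// Wrap an existing location with an identifier, as a NameLoc does.
Loc withName(const Loc &inner, const std::string &name) {
  return Loc{inner.file, name};
}

// A rewrite that passes op.getLoc() to newly created ops forwards the whole
// location, so the attached name travels with it for free.
Loc cloneLocForNewOp(const Loc &originalOpLoc) { return originalOpLoc; }
```

This is why, in the pass output below, ops created from a named op inherit its name without any extra bookkeeping.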

I prototyped this using @Wheest's current PR (with a few small changes) and got this:

Source IR:

#loc = loc("triton/python/examples/empty.py":17:0)
module {
  tt.func public @add_kernel(
    %in_ptr0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
    %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
    %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0)) -> tensor<1024xf32> attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %pid = tt.get_program_id x : i32 loc(#loc2)
    %block_start = arith.muli %pid, %c1024_i32 : i32 loc(#loc3)
    %make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> loc(#loc5)
    %offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> loc(#loc5)
    %in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
    %addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
    %val = tt.load %addr : tensor<1024x!tt.ptr<f32>> loc(#loc7)
    tt.return %val : tensor<1024xf32> loc(#loc8)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("triton/python/examples/empty.py":24:24)
#loc3 = loc("triton/python/examples/empty.py":25:24)
#loc4 = loc("triton/python/examples/empty.py":26:41)
#loc5 = loc("triton/python/examples/empty.py":26:28)
#loc6 = loc("triton/python/examples/empty.py":27:26)
#loc7 = loc("triton/python/examples/empty.py":27:16)
#loc8 = loc("triton/python/examples/empty.py":29:11)

Just parsing:

#loc = triton/python/examples/empty.py:17:0
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
    %0 = tt.get_program_id x : i32 "pid"(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 "block_start"(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(#loc5)
    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(#loc6)
    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(#loc6)
    %7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    tt.return %7 : tensor<1024xf32> ""(#loc8)
  } triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11

After some passes:

#loc = triton/python/examples/empty.py:17:0
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false, rewritten} {
    %c0_i32 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c0_i32_0 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c0_i32_1 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
    %3 = tt.get_program_id x : i32 "pid"(#loc2)
    %4 = arith.muli %3, %c1024_i32 : i32 "block_start"(#loc3)
    %5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
    %6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
    %7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(#loc5)
    %8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(#loc6)
    %cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
    %c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
    %9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
    %10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
    %11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
    %12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
    %13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    %14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(#loc7)
    %15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    tt.return %15 : tensor<1024xf32> ""(#loc8)
  } triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11

I'll point out what's nice here: %addr propagates to all of the created/inserted ops that had addPtrOp.getLoc() passed to them:

%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)

Note this example started with debug info in the source, and the changes I currently have are compatible with that; of course, they also work if the original source doesn't have any debug info:

module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
    %0 = tt.get_program_id x : i32 "pid"(within split...)
    %1 = arith.muli %0, %c1024_i32 : i32 "block_start"(within split...)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
    %4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(within split...)
    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(within split...)
    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(within split...)
    %7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(within split...)
    tt.return %7 : tensor<1024xf32> within split...
  } within split...
} within split...

and

module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
    %c0_i32 = arith.constant 0 : i32 within split...
    %0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c0_i32_0 = arith.constant 0 : i32 within split...
    %1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c0_i32_1 = arith.constant 0 : i32 within split...
    %2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
    %3 = tt.get_program_id x : i32 "pid"(within split...)
    %4 = arith.muli %3, %c1024_i32 : i32 "block_start"(within split...)
    %5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
    %6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
    %7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(within split...)
    %8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(within split...)
    %cst = arith.constant dense<0> : tensor<1024xi32> "addr"(within split...)
    %c0_i32_2 = arith.constant 0 : i32 "addr"(within split...)
    %9 = arith.addi %4, %c0_i32_2 : i32 "addr"(within split...)
    %10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(within split...)
    %11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(within split...)
    %12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(within split...)
    %13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(within split...)
    %14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(within split...)
    %15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(within split...)
    tt.return %15 : tensor<1024xf32> within split...
  } within split...
} within split...

@renxida

renxida commented Dec 12, 2024

Being able to retain the original name of each SSA value would be so nice for debugging.
