[mlir] Retain original identifier names for debugging #79704
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core
Author: Perry Gibson (Wheest)
Changes
This PR implements retention of MLIR identifier names (e.g., %my_val, ^bb_foo) for debugging and development purposes.
A wider discussion of this feature is in this discourse thread.
A motivating example is that right now, IR generation drops all meaningful identifier names, which could be useful for developers trying to understand their passes, or for other tooling (e.g., MLIR code formatters).
func.func @add_one(%my_input: f64) -> f64 {
  %my_constant = arith.constant 1.00000e+00 : f64
  %my_output = arith.addf %my_input, %my_constant : f64
  return %my_output : f64
}
⬇️ becomes
func.func @add_one(%arg0: f64) -> f64 {
  %cst = arith.constant 1.000000e+00 : f64
  %0 = arith.addf %arg0, %cst : f64
  return %0 : f64
}
The solution this PR implements is to store this metadata inside the attribute dictionary of operations, under a special namespace (e.g., mlir.resultNames). This means that this optional feature (turned on by the mlir-opt flag --retain-identifier-names) does not incur any additional overhead, except in text parsing and printing (AsmParser/Parser.cpp and IR/AsmPrinter.cpp).
Alternative solutions, such as adding a string field to the Value class, reparsing location information, and adapting OpAsmInterface, are discussed in the relevant discourse thread.
I've implemented some initial test cases in mlir/test/IR/print-retain-identifiers.mlir. This covers things such as SSA results (with single return values), basic blocks and their arguments, multiple return values, and multiple return values with a grouped value tuple.
A case that I know not to work is when a func.func argument is not used. This is because we recover the SSA name of these arguments from operations which use them. You can see in the first test case that the 2nd argument is not used, so it will always default back to %arg0.
I do not have test cases for how the system handles code transformations, and am open to suggestions for additional tests to include. Also note that this is my most substantial contribution to the codebase to date, so I may require some shepherding with regards to coding style or use of core LLVM library constructs.
Full diff: https://github.com/llvm/llvm-project/pull/79704.diff
6 Files Affected:
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 42cbedcf9f8837..9c4eadb04cdf2f 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -144,8 +144,7 @@ class AsmResourceBlob {
/// Return the underlying data as an array of the given type. This is an
/// inherrently unsafe operation, and should only be used when the data is
/// known to be of the correct type.
- template <typename T>
- ArrayRef<T> getDataAs() const {
+ template <typename T> ArrayRef<T> getDataAs() const {
return llvm::ArrayRef<T>((const T *)data.data(), data.size() / sizeof(T));
}
@@ -464,8 +463,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
- FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr,
+ bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
+ retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -476,6 +477,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
+ /// Returns if the parser should retain identifier names collected using
+ /// parsing.
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
+
/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -513,6 +518,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
+ bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21..a85dca186a4f3c 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
+ /// Print the pass-pipeline as text before executing.
+ MlirOptMainConfig &retainIdentifierNames(bool retain) {
+ retainIdentifierNamesFlag = retain;
+ return *this;
+ }
+ bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -226,6 +233,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;
+ /// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
+ bool retainIdentifierNamesFlag = false;
+
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..247e99e61c2c01 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -611,6 +611,10 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();
+ /// Store the SSA names for the current operation as attrs for debug purposes.
+ void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
+ DenseMap<Value, StringRef> argNames;
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
@@ -1268,6 +1272,58 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}
+/// Store the SSA names for the current operation as attrs for debug purposes.
+void OperationParser::storeSSANames(Operation *&op,
+ ArrayRef<ResultRecord> resultIDs) {
+
+ // Store the name(s) of the result(s) of this operation.
+ if (op->getNumResults() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> resultNames;
+ for (const ResultRecord &resIt : resultIDs) {
+ resultNames.push_back(std::get<0>(resIt).drop_front(1));
+ // Insert empty string for sub-results/result groups
+ for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
+ resultNames.push_back(llvm::StringRef());
+ }
+ op->setDiscardableAttr("mlir.resultNames",
+ builder.getStrArrayAttr(resultNames));
+ }
+
+ // Store the name information of the arguments of this operation.
+ if (op->getNumOperands() > 0) {
+ llvm::SmallVector<llvm::StringRef, 1> opArgNames;
+ for (auto &operand : op->getOpOperands()) {
+ auto it = argNames.find(operand.get());
+ if (it != argNames.end())
+ opArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setDiscardableAttr("mlir.opArgNames",
+ builder.getStrArrayAttr(opArgNames));
+ }
+
+ // Store the name information of the block that contains this operation.
+ Block *blockPtr = op->getBlock();
+ for (const auto &map : blocksByName) {
+ for (const auto &entry : map) {
+ if (entry.second.block == blockPtr) {
+ op->setDiscardableAttr("mlir.blockName",
+ StringAttr::get(getContext(), entry.first));
+
+ // Store block arguments, if present
+ llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
+
+ for (BlockArgument arg : blockPtr->getArguments()) {
+ auto it = argNames.find(arg);
+ if (it != argNames.end())
+ blockArgNames.push_back(it->second.drop_front(1));
+ }
+ op->setAttr("mlir.blockArgNames",
+ builder.getStrArrayAttr(blockArgNames));
+ }
+ }
+ }
+}
+
namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
@@ -1672,6 +1728,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);
+
+ // Optionally store argument name for debug purposes
+ if (parser.getState().config.shouldRetainIdentifierNames())
+ parser.argNames.insert({value, operand.name});
+
return success();
}
return failure();
@@ -2031,6 +2092,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
+
+ // If enabled, store the SSA name(s) for the operation
+ if (state.config.shouldRetainIdentifierNames())
+ storeSSANames(op, resultIDs);
+
if (parseTrailingLocationSpecifier(op))
return nullptr;
@@ -2355,6 +2421,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);
+
+ // Optionally store argument name for debug purposes
+ if (state.config.shouldRetainIdentifierNames())
+ argNames.insert({arg, useInfo.name});
}
// If the argument has an explicit loc(...) specifier, parse and apply
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8a..84603bb6ebfba3 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
/// Parse a type list.
-/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
+/// This is out-of-line to work-around
+/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
- return parseCommaSeparatedList(
- [&]() { return parseType(result.emplace_back()); });
- }
-
-
+ return parseCommaSeparatedList(
+ [&]() { return parseType(result.emplace_back()); });
+}
//===----------------------------------------------------------------------===//
// DialectAsmPrinter
@@ -1299,6 +1298,11 @@ class SSANameState {
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
+ /// Set the original identifier names if available. Used in debugging with
+ /// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ void setRetainedIdentifierNames(Operation &op,
+ SmallVector<int, 2> &resultGroups);
+
/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
DenseMap<Value, unsigned> valueIDs;
@@ -1568,6 +1572,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+ // Set the original identifier names if available. Used in debugging with
+ // `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
+ setRetainedIdentifierNames(op, resultGroups);
+
unsigned numResults = op.getNumResults();
if (numResults == 0) {
// If value users should be printed, operations with no result need an id.
@@ -1590,6 +1598,64 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}
+void SSANameState::setRetainedIdentifierNames(
+ Operation &op, SmallVector<int, 2> &resultGroups) {
+ // Get the original names for the results if available
+ if (ArrayAttr resultNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
+ auto resultNames = resultNamesAttr.getValue();
+ auto results = op.getResults();
+ // Conservative in the case that the #results has changed
+ for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
+ auto resultName = resultNames[i].cast<StringAttr>().strref();
+ if (!resultName.empty()) {
+ if (!usedNames.count(resultName))
+ setValueName(results[i], resultName);
+ // If a result has a name, it is the start of a result group.
+ if (i > 0)
+ resultGroups.push_back(i);
+ }
+ }
+ op.removeDiscardableAttr("mlir.resultNames");
+ }
+
+ // Get the original name for the op args if available
+ if (ArrayAttr opArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
+ auto opArgNames = opArgNamesAttr.getValue();
+ auto opArgs = op.getOperands();
+ // Conservative in the case that the #operands has changed
+ for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
+ auto opArgName = opArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(opArgName))
+ setValueName(opArgs[i], opArgName);
+ }
+ op.removeDiscardableAttr("mlir.opArgNames");
+ }
+
+ // Get the original name for the block if available
+ if (StringAttr blockNameAttr =
+ op.getAttrOfType<StringAttr>("mlir.blockName")) {
+ blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
+ op.removeDiscardableAttr("mlir.blockName");
+ }
+
+ // Get the original name for the block args if available
+ if (ArrayAttr blockArgNamesAttr =
+ op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
+ auto blockArgNames = blockArgNamesAttr.getValue();
+ auto blockArgs = op.getBlock()->getArguments();
+ // Conservative in the case that the #args has changed
+ for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
+ auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
+ if (!usedNames.count(blockArgName))
+ setValueName(blockArgs[i], blockArgName);
+ }
+ op.removeDiscardableAttr("mlir.blockArgNames");
+ }
+ return;
+}
+
void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 5395aa2b502d78..c4482435861590 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -149,6 +149,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
+ static cl::opt<bool, /*ExternalStorage=*/true> retainIdentifierNames(
+ "retain-identifier-names",
+ cl::desc("Retain the original names of identifiers when printing"),
+ cl::location(retainIdentifierNamesFlag), cl::init(false));
+
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
@@ -359,8 +364,9 @@ performActions(raw_ostream &os,
// untouched.
PassReproducerOptions reproOptions;
FallbackAsmResourceMap fallbackResourceMap;
- ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
- &fallbackResourceMap);
+ ParserConfig parseConfig(
+ context, /*verifyAfterParse=*/true, &fallbackResourceMap,
+ /*retainIdentifierName=*/config.shouldRetainIdentifierNames());
if (config.shouldRunReproducer())
reproOptions.attachResourceParser(parseConfig);
diff --git a/mlir/test/IR/print-retain-identifiers.mlir b/mlir/test/IR/print-retain-identifiers.mlir
new file mode 100644
index 00000000000000..b3e4f075b3936a
--- /dev/null
+++ b/mlir/test/IR/print-retain-identifiers.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt -retain-identifier-names %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// Test SSA results (with single return values)
+//===----------------------------------------------------------------------===//
+
+// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.000000e+00 : f64
+ %my_constant = arith.constant 1.000000e+00 : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64
+ %my_output = arith.addf %my_input, %my_constant : f64
+ // CHECK: return %my_output : f64
+ return %my_output : f64
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test basic blocks and their arguments
+//===----------------------------------------------------------------------===//
+
+func.func @simple(i64, i1) -> i64 {
+^bb_alpha(%a: i64, %cond: i1):
+ // CHECK: cf.cond_br %cond, ^bb_beta, ^bb_gamma
+ cf.cond_br %cond, ^bb_beta, ^bb_gamma
+
+// CHECK: ^bb_beta: // pred: ^bb_alpha
+^bb_beta:
+ // CHECK: cf.br ^bb_delta(%a : i64)
+ cf.br ^bb_delta(%a: i64)
+
+// CHECK: ^bb_gamma: // pred: ^bb_alpha
+^bb_gamma:
+ // CHECK: %b = arith.addi %a, %a : i64
+ %b = arith.addi %a, %a : i64
+ // CHECK: cf.br ^bb_delta(%b : i64)
+ cf.br ^bb_delta(%b: i64)
+
+// CHECK: ^bb_delta(%c: i64): // 2 preds: ^bb_gamma, ^bb_beta
+^bb_delta(%c: i64):
+ // CHECK: cf.br ^bb_eps(%c, %a : i64, i64)
+ cf.br ^bb_eps(%c, %a : i64, i64)
+
+// CHECK: ^bb_eps(%d: i64, %e: i64): // pred: ^bb_delta
+^bb_eps(%d : i64, %e : i64):
+ // CHECK: %f = arith.addi %d, %e : i64
+ %f = arith.addi %d, %e : i64
+ return %f : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test multiple return values
+//===----------------------------------------------------------------------===//
+
+func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %min, %max = scf.if %gt -> (f64, f64) {
+ %min, %max = scf.if %gt -> (f64, f64) {
+ scf.yield %b, %a : f64, f64
+ } else {
+ scf.yield %a, %b : f64, f64
+ }
+ // CHECK: return %min, %max : f64, f64
+ return %min, %max : f64, f64
+}
+
+// -----
+
+////===----------------------------------------------------------------------===//
+// Test multiple return values, with a grouped value tuple
+//===----------------------------------------------------------------------===//
+
+func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
+ // Find the max between %a and %b,
+ // with %c and %d being other values that are returned.
+ %gt = arith.cmpf "ogt", %a, %b : f64
+ // CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
+ scf.yield %b, %a, %c, %d : f64, f64, f64, f64
+ } else {
+ scf.yield %a, %b, %d, %c : f64, f64, f64, f64
+ }
+ // CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+ return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
+}
+
+// -----
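As a usage note (not part of the diff): with the ParserConfig change above, a downstream C++ caller could opt in roughly as follows. This is a minimal sketch; the input path and the use of parseSourceFile here are illustrative assumptions, not code from this PR.

#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext context;
  // Dialect registration is omitted for brevity.
  mlir::FallbackAsmResourceMap fallbackResourceMap;

  // Opt in to keeping `%my_var`-style names via the new ParserConfig
  // parameter added in this PR.
  mlir::ParserConfig config(&context, /*verifyAfterParse=*/true,
                            &fallbackResourceMap,
                            /*retainIdentifierNames=*/true);

  // "input.mlir" is a placeholder path, not a file from this PR.
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceFile<mlir::ModuleOp>("input.mlir", config);
  if (!module)
    return 1;

  // Printing goes through AsmPrinter, which consumes the retained names.
  module->print(llvm::outs());
  return 0;
}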
✅ With the latest revision this PR passed the C/C++ code formatter.
This is a fairly straightforward solution; however, the requirements aren't clear enough to me to judge whether this is really the right solution (for example, the interactions with the AsmOpInterface can be subtle).
I also have concerns about using the dictionary of attributes here: the main purpose of the textual IR is to be faithful to the in-memory representation, and this is breaking this contract.
Thanks for the review! I've tried to enumerate some requirements gathered from the discourse discussion and my own opinions, see below. First though, regarding the concern that the textual IR should stay faithful to the in-memory representation:
This is a valid concern, but I'll try to advocate for this being okay and/or not unprecedented.
I could add an additional flag so that the attribute dictionary is not cleaned up at print time. This would generate IR which looks like the following (IR with no attr dict cleanup):
module {
  func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
    %my_constant = arith.constant {mlir.resultNames = ["my_constant"]} 1.000000e+00 : f64
    %my_output = arith.addf %my_input, %my_constant {mlir.opArgNames = ["my_input", "my_constant"], mlir.resultNames = ["my_output"]} : f64
    return {mlir.opArgNames = ["my_output"]} %my_output : f64
  }
}
Secondly, regarding the subtle interactions with the AsmOpInterface:
Yes, although the function that actually sets the names is a lambda which is passed by the caller to the operation (e.g., the lambda in AsmPrinter we care about). This means that although we can't control what the operation does in its implementation, we can control what we do with the string it gives us. Your comment makes me realise that I should add test cases where our chosen names might clash with the names chosen by the operation itself.
Requirements
Legend
Functionality
Performance
Maintainability
🟠: Issues to discuss or address
Functionality
Maintainability
I've managed to fix the issue around unused func.func arguments. Reading the Func dialect docs, they say that function arguments are handled as the arguments of the entry block of the function body.
Therefore, I handle this case in a generic way. Also, following your suggestion around how the AsmOpInterface behaves, I added a test case where I set names that explicitly clash with and contradict the default anonymous names.
I would want to see at least a design for another more principled approach before possibly considering this one.
At the moment I'm more on the side of "this feature isn't worth the drawback of polluting the IR".
Excluding adding a new member variable to the Value class (which adds an unacceptable constant overhead), I think there are two approaches that are worth considering.
1. Location based approach
This would not require changes to the IR; however, it would require re-parsing, which could add a lot more complexity (e.g., how do we reassociate a newly parsed Value and its name with the original Value?). There are also clients of location info, for whom this feature could cause problems.
2. Use of external data structure
We could create an external data structure, e.g., a map from each Value to its original name (a rough sketch of this idea is shown below). However, such a structure typically needs an owner (e.g., mlir-opt) and its lifetime has to be managed somewhere. We have something similar to this already in the codebase.
A direct Value mapping approach also has the advantage, compared to the attribute dict approach, of reducing ambiguity in the case of IR transformation: if we introduce new arguments or operations, unless the Value is the same, we don't touch the name.
Personally, I think option 2 is a more straightforward and less intrusive solution, especially if we made the map a user-managed object (e.g., managed by mlir-opt).
Regarding the attribute dictionary approach, the only situation in which the IR would be polluted is when this flag is enabled, which limits its impact.
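As a rough illustration only (not code from this PR; the type and member names below are made up), option 2 could be as small as a side table keyed on Value:

#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <string>

// Hypothetical side table: the parser would record names here and the
// printer would consult it. Keyed on Value, so names follow values
// directly; values created later by transformations simply have no entry.
struct RetainedNameMap {
  llvm::DenseMap<mlir::Value, std::string> names;

  void record(mlir::Value value, llvm::StringRef name) {
    names[value] = name.str();
  }

  // Returns the original identifier, or an empty string if unknown.
  llvm::StringRef lookup(mlir::Value value) const {
    auto it = names.find(value);
    return it == names.end() ? llvm::StringRef() : llvm::StringRef(it->second);
  }
};

Storing std::string copies (rather than StringRef) avoids dangling references once the parsed source buffer is freed, at the cost of a small copy per name.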
One thing I was wondering about here (and sorry, I haven't had time to follow the other discussion - it also sounded like there was some confusion with what I said there, as I saw 2 uses and meant different comments for each): for 2, have we considered enabling creating an AsmState from an AsmParserState? That would handle this, right? (It's in the 2nd approach model.) It would not handle tracking mutations - but then again, mutation tracking would be duplicated via other means. It may require allowing mutations of AsmState ... so potentially a separate mapping may be better.
Thanks @jpienaar, I hadn't considered that, but it looks like AsmState could be a good way to manage the name data.
That looks like a good way to store this data, rather than letting mlir-opt handle it as I suggested.
There is a member variable there that could hold this kind of data. Regarding making an AsmState mutable, we could have something akin to the separate mapping you mention, populated at parse time. Thoughts on this design? I'll try making another version that has the external data structure first, and we can iterate on where it should be managed.
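For context (plain existing API, not new to this PR): AsmState is already the object threaded through printing calls, which is why it is a plausible home, or neighbour, for such a name table. A minimal example of how it is used today:

#include "mlir/IR/AsmState.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

// The SSA numbering/naming state is computed once in AsmState and reused
// across print calls on the same operation.
void printTwice(mlir::Operation *op) {
  mlir::AsmState state(op);
  op->print(llvm::outs(), state);
  op->print(llvm::outs(), state);
}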
Can we iterate on this? Preserving legible SSA names into the parse and even through passes is immensely valuable/helpful for debugging everything. Note, I don't understand the relationship to formatters, so what follows doesn't address that use case. I propose we store the identifier in the Location info as a NameLoc. I prototyped this using @Wheest's current PR (a few small changes) and I got this:
Source IR:
#loc = loc("triton/python/examples/empty.py":17:0)
module {
tt.func public @add_kernel(
%in_ptr0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0)) -> tensor<1024xf32> attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
%pid = tt.get_program_id x : i32 loc(#loc2)
%block_start = arith.muli %pid, %c1024_i32 : i32 loc(#loc3)
%make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
%block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> loc(#loc5)
%offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> loc(#loc5)
%in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
%addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
%val = tt.load %addr : tensor<1024x!tt.ptr<f32>> loc(#loc7)
tt.return %val : tensor<1024xf32> loc(#loc8)
} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("triton/python/examples/empty.py":24:24)
#loc3 = loc("triton/python/examples/empty.py":25:24)
#loc4 = loc("triton/python/examples/empty.py":26:41)
#loc5 = loc("triton/python/examples/empty.py":26:28)
#loc6 = loc("triton/python/examples/empty.py":27:26)
#loc7 = loc("triton/python/examples/empty.py":27:16)
#loc8 = loc("triton/python/examples/empty.py":29:11)
Just parsing:
#loc = triton/python/examples/empty.py:17:0
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
%0 = tt.get_program_id x : i32 "pid"(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 "block_start"(#loc3)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
%4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(#loc5)
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(#loc6)
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(#loc6)
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
tt.return %7 : tensor<1024xf32> ""(#loc8)
} triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
After some passes:
#loc = triton/python/examples/empty.py:17:0
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false, rewritten} {
%c0_i32 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c0_i32_0 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c0_i32_1 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
%3 = tt.get_program_id x : i32 "pid"(#loc2)
%4 = arith.muli %3, %c1024_i32 : i32 "block_start"(#loc3)
%5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
%6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
%7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(#loc5)
%8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(#loc6)
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
%13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(#loc7)
%14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(#loc7)
%15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
tt.return %15 : tensor<1024xf32> ""(#loc8)
} triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
I'll point out what's nice here: the new operations produced by the rewrites keep pointing at the original name through their location, e.g.:
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
Note this example started with debug info in the source and changes I currently have are compatible with that, but they of course also work if the original source doesn't have any info:
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
%0 = tt.get_program_id x : i32 "pid"(within split...)
%1 = arith.muli %0, %c1024_i32 : i32 "block_start"(within split...)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
%4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(within split...)
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(within split...)
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(within split...)
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(within split...)
tt.return %7 : tensor<1024xf32> within split...
} within split...
} within split...
and
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
%c0_i32 = arith.constant 0 : i32 within split...
%0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c0_i32_0 = arith.constant 0 : i32 within split...
%1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c0_i32_1 = arith.constant 0 : i32 within split...
%2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
%3 = tt.get_program_id x : i32 "pid"(within split...)
%4 = arith.muli %3, %c1024_i32 : i32 "block_start"(within split...)
%5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
%6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
%7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(within split...)
%8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(within split...)
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(within split...)
%c0_i32_2 = arith.constant 0 : i32 "addr"(within split...)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(within split...)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(within split...)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(within split...)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(within split...)
%13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(within split...)
%14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(within split...)
%15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(within split...)
tt.return %15 : tensor<1024xf32> within split...
} within split...
} within split...
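To make the location-based proposal concrete, the core of such a prototype presumably boils down to wrapping each parsed value's location in a NameLoc, roughly as below (my own sketch, not the prototype's actual code; the helper name is made up):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/StringRef.h"

// Fuse the parsed SSA identifier into the op's location so that the name
// travels with the location info through passes, as in the output above.
static void attachRetainedName(mlir::Operation *op, llvm::StringRef ssaName) {
  mlir::MLIRContext *ctx = op->getContext();
  mlir::Location named =
      mlir::NameLoc::get(mlir::StringAttr::get(ctx, ssaName), op->getLoc());
  op->setLoc(named);
}

Block arguments could presumably be handled the same way through their own locations.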
Being able to retain the original name of each SSA value would be so nice for debugging.