Skip to content

[mlir] Retain original identifier names for debugging #79704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion mlir/include/mlir/IR/AsmState.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
FallbackAsmResourceMap *fallbackResourceMap = nullptr)
FallbackAsmResourceMap *fallbackResourceMap = nullptr,
bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
Expand All @@ -476,6 +478,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }

/// Returns true if the parser should retain the identifier names collected
/// during parsing.
bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }

/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
Expand Down Expand Up @@ -513,6 +519,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ class MlirOptMainConfig {
/// Reproducer file generation (no crash required).
StringRef getReproducerFilename() const { return generateReproducerFileFlag; }

/// Set whether the original identifier names should be retained in the
/// output (e.g., `%my_var` instead of `%0`).
MlirOptMainConfig &retainIdentifierNames(bool retain) {
retainIdentifierNamesFlag = retain;
return *this;
}
/// Returns the current value of the retain-identifier-names flag.
bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }

protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
Expand Down Expand Up @@ -226,6 +233,9 @@ class MlirOptMainConfig {
/// the corresponding line. This is meant for implementing diagnostic tests.
bool verifyDiagnosticsFlag = false;

/// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
bool retainIdentifierNamesFlag = false;

/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;

Expand Down
86 changes: 86 additions & 0 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,11 @@ class OperationParser : public Parser {
/// an object of type 'OperationName'. Otherwise, failure is returned.
FailureOr<OperationName> parseCustomOperationName();

/// Store the identifier names for the current operation as attrs for debug
/// purposes. `resultIDs` holds, for each parsed result name of `op`, the name
/// itself and the number of results that name covers.
void storeIdentifierNames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
/// Maps a value to the identifier it was spelled with in the source. Entries
/// are only inserted when the parser config asks for identifier names to be
/// retained. The stored names still carry their leading character, which
/// `storeIdentifierNames` drops before attaching the name to an operation.
DenseMap<Value, StringRef> argNames;

//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -1268,6 +1273,70 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
/*allowEmptyList=*/false);
}

/// Store the SSA names for the current operation as attrs for debug purposes.
///
/// The names are attached to `op` as discardable attributes:
///  * `mlir.resultNames`:    one entry per result. The first result of every
///                           result group carries the group name, the other
///                           results of the group carry an empty string.
///  * `mlir.opArgNames`:     one entry per operand.
///  * `mlir.blockName`:      name of the block containing `op`, if that block
///                           was given a name in the source.
///  * `mlir.blockArgNames`:  one entry per argument of that block.
///  * `mlir.regionArgNames`: one entry per argument of the first region.
///
/// Every array is positional: entry `i` always describes value `i`. A value
/// without a recorded name is represented by an empty string so that the
/// remaining names do not shift onto the wrong values.
///
/// \param op        The operation that was just created by the parser.
/// \param resultIDs The (name, result count, location) records parsed for the
///                  results of `op`.
void OperationParser::storeIdentifierNames(Operation *&op,
                                           ArrayRef<ResultRecord> resultIDs) {
  // Returns the name recorded for `value` without its leading character, or
  // an empty string when nothing was recorded for it.
  auto getRecordedName = [this](Value value) -> llvm::StringRef {
    auto it = argNames.find(value);
    if (it == argNames.end() || it->second.empty())
      return llvm::StringRef();
    return it->second.drop_front(1);
  };

  // Store the name(s) of the result(s) of this operation.
  if (op->getNumResults() > 0) {
    llvm::SmallVector<llvm::StringRef, 1> resultNames;
    for (const ResultRecord &record : resultIDs) {
      resultNames.push_back(std::get<0>(record).drop_front(1));
      // Insert empty strings for the remaining results of a result group.
      for (unsigned int i = 1; i < std::get<1>(record); ++i)
        resultNames.push_back(llvm::StringRef());
    }
    op->setDiscardableAttr("mlir.resultNames",
                           builder.getStrArrayAttr(resultNames));
  }

  // Store the name information of the arguments of this operation.
  if (op->getNumOperands() > 0) {
    llvm::SmallVector<llvm::StringRef, 1> opArgNames;
    for (auto &operand : op->getOpOperands())
      opArgNames.push_back(getRecordedName(operand.get()));
    op->setDiscardableAttr("mlir.opArgNames",
                           builder.getStrArrayAttr(opArgNames));
  }

  // Store the name information of the block that contains this operation.
  if (Block *parentBlock = op->getBlock()) {
    bool foundBlock = false;
    for (const auto &scope : blocksByName) {
      for (const auto &entry : scope) {
        if (entry.second.block != parentBlock)
          continue;

        op->setDiscardableAttr("mlir.blockName",
                               StringAttr::get(getContext(), entry.first));

        // Store block arguments, if present.
        llvm::SmallVector<llvm::StringRef, 1> blockArgNames;
        for (BlockArgument arg : parentBlock->getArguments())
          blockArgNames.push_back(getRecordedName(arg));
        op->setDiscardableAttr("mlir.blockArgNames",
                               builder.getStrArrayAttr(blockArgNames));

        foundBlock = true;
        break;
      }
      // Stop at the first match; there is nothing more to record.
      if (foundBlock)
        break;
    }
  }

  // Store names of region arguments (e.g., for FuncOps).
  if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) {
    llvm::SmallVector<llvm::StringRef, 1> regionArgNames;
    for (BlockArgument arg : op->getRegion(0).getArguments())
      regionArgNames.push_back(getRecordedName(arg));
    op->setDiscardableAttr("mlir.regionArgNames",
                           builder.getStrArrayAttr(regionArgNames));
  }
}

namespace {
// RAII-style guard for cleaning up the regions in the operation state before
// deleting them. Within the parser, regions may get deleted if parsing failed,
Expand Down Expand Up @@ -1672,6 +1741,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
SmallVectorImpl<Value> &result) override {
if (auto value = parser.resolveSSAUse(operand, type)) {
result.push_back(value);

// Optionally store argument name for debug purposes
if (parser.getState().config.shouldRetainIdentifierNames())
parser.argNames.insert({value, operand.name});

return success();
}
return failure();
Expand Down Expand Up @@ -2031,6 +2105,11 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {

// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);

// If enabled, store the original identifier name(s) for the operation
if (state.config.shouldRetainIdentifierNames())
storeIdentifierNames(op, resultIDs);

if (parseTrailingLocationSpecifier(op))
return nullptr;

Expand Down Expand Up @@ -2180,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
if (state.asmState)
state.asmState->addDefinition(arg, argInfo.location);

if (state.config.shouldRetainIdentifierNames())
argNames.insert({arg, argInfo.name});

// Record the definition for this argument.
if (addDefinition(argInfo, arg))
return failure();
Expand Down Expand Up @@ -2355,6 +2437,10 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
arg = owner->addArgument(type, loc);

// Optionally store argument name for debug purposes
if (state.config.shouldRetainIdentifierNames())
argNames.insert({arg, useInfo.name});
}

// If the argument has an explicit loc(...) specifier, parse and apply
Expand Down
125 changes: 107 additions & 18 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ OpAsmParser::~OpAsmParser() = default;
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }

/// Parse a type list.
/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
/// This is out-of-line to work-around
/// https://github.com/llvm/llvm-project/issues/62918
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
}


return parseCommaSeparatedList(
[&]() { return parseType(result.emplace_back()); });
}

//===----------------------------------------------------------------------===//
// DialectAsmPrinter
Expand Down Expand Up @@ -982,7 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
StringRef allowedPunctChars = "$._-",
bool allowTrailingDigit = true) {
bool allowTrailingDigit = true,
bool allowNumeric = false) {
assert(!name.empty() && "Shouldn't have an empty name here");

auto copyNameToBuffer = [&] {
Expand All @@ -998,16 +998,17 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,

// Check to see if this name is valid. If it starts with a digit, then it
// could conflict with the autogenerated numeric ID's, so add an underscore
// prefix to avoid problems.
if (isdigit(name[0])) {
// prefix to avoid problems. This can be overridden by setting allowNumeric.
if (isdigit(name[0]) && !allowNumeric) {
buffer.push_back('_');
copyNameToBuffer();
return buffer;
}

// If the name ends with a trailing digit, add a '_' to avoid potential
// conflicts with autogenerated ID's.
if (!allowTrailingDigit && isdigit(name.back())) {
// conflicts with autogenerated ID's. This can be overridden by setting
// allowNumeric.
if (!allowTrailingDigit && isdigit(name.back()) && !allowNumeric) {
copyNameToBuffer();
buffer.push_back('_');
return buffer;
Expand Down Expand Up @@ -1293,11 +1294,18 @@ class SSANameState {
std::optional<int> &lookupResultNo) const;

/// Set a special value name for the given value.
void setValueName(Value value, StringRef name);
void setValueName(Value value, StringRef name, bool allowNumeric = false);

/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
StringRef uniqueValueName(StringRef name, bool allowNumeric = false);

/// Set the original identifier names if available. Used in debugging with
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
void setRetainedIdentifierNames(Operation &op,
SmallVector<int, 2> &resultGroups,
bool hasRegion = false);
void setRetainedIdentifierNames(Region &region);

/// This is the value ID for each SSA value. If this returns NameSentinel,
/// then the valueID has an entry in valueNames.
Expand Down Expand Up @@ -1486,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
setValueName(arg, name);
};

// Use manually specified region arg names if available
setRetainedIdentifierNames(region);

if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
Expand Down Expand Up @@ -1537,7 +1548,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
// Function used to set the special result names for the operation.
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
auto setResultNameFn = [&](Value result, StringRef name) {
assert(!valueIDs.count(result) && "result numbered multiple times");
// Case where the result has already been named
if (valueIDs.count(result))
return;
// assert(!valueIDs.count(result) && "result numbered multiple times");
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
setValueName(result, name);

Expand All @@ -1561,6 +1575,10 @@ void SSANameState::numberValuesInOp(Operation &op) {
blockNames[block] = {-1, name};
};

// Set the original identifier names if available. Used in debugging with
// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
setRetainedIdentifierNames(op, resultGroups);

if (!printerFlags.shouldPrintGenericOpForm()) {
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
asmInterface.getAsmBlockNames(setBlockNameFn);
Expand Down Expand Up @@ -1590,6 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) {
}
}

/// Apply the identifier names stored on `op` in the `mlir.*Names` attributes
/// to the values they describe, then remove those attributes from `op`.
///
/// \param op           The operation carrying the retained-name attributes.
/// \param resultGroups Receives the index of every named result other than
///                     the first one. Only touched when `hasRegion` is false.
/// \param hasRegion    When true, only the arguments of the first region of
///                     `op` are named (`mlir.regionArgNames`). When false, the
///                     results, operands, parent block and parent block
///                     arguments of `op` are named instead.
///
/// NOTE(review): this strips attributes from the IR while names are being
/// assigned; confirm that mutating the operation here is acceptable.
void SSANameState::setRetainedIdentifierNames(Operation &op,
                                              SmallVector<int, 2> &resultGroups,
                                              bool hasRegion) {
  // Reads the string array stored under `attrName`, pairs its entries with the
  // values produced by `getValues` by position and names each value whose
  // entry is a non-empty string. `onNamed`, when provided, is invoked with the
  // index of every such entry. The attribute is removed afterwards.
  auto applyRetainedNames = [this, &op](StringRef attrName, auto getValues,
                                        std::function<void(int)> onNamed =
                                            nullptr) {
    ArrayAttr namesAttr = op.getAttrOfType<ArrayAttr>(attrName);
    if (!namesAttr)
      return;

    auto names = namesAttr.getValue();
    auto values = getValues();
    // Conservative in case the number of values has changed.
    for (size_t i = 0; i < values.size() && i < names.size(); ++i) {
      // The attribute is discardable and may have been written by hand, so
      // tolerate entries that are not strings instead of asserting on them.
      auto nameAttr = dyn_cast<StringAttr>(names[i]);
      if (!nameAttr)
        continue;
      StringRef name = nameAttr.strref();
      if (name.empty())
        continue;

      // Only take the name if it is still free in the current scope.
      if (!this->usedNames.count(name))
        this->setValueName(values[i], name, /*allowNumeric=*/true);
      if (onNamed)
        onNamed(static_cast<int>(i));
    }
    op.removeDiscardableAttr(attrName);
  };

  if (hasRegion) {
    // Get the original name(s) for the region arg(s) if available (e.g., for
    // FuncOp args). Requires the hasRegion flag to ensure scoping is correct.
    if (op.getNumRegions() > 0 && op.getRegion(0).getNumArguments() > 0)
      applyRetainedNames("mlir.regionArgNames",
                         [&]() { return op.getRegion(0).getArguments(); });
    return;
  }

  // Get the original names for the results if available. Every named result
  // after the first one starts a new result group.
  applyRetainedNames(
      "mlir.resultNames", [&]() { return op.getResults(); },
      [&resultGroups](int i) {
        if (i > 0)
          resultGroups.push_back(i);
      });

  // Get the original name for the op args if available.
  applyRetainedNames("mlir.opArgNames", [&]() { return op.getOperands(); });

  // Get the original name for the block if available.
  if (StringAttr blockNameAttr =
          op.getAttrOfType<StringAttr>("mlir.blockName")) {
    blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
    op.removeDiscardableAttr("mlir.blockName");
  }

  // Get the original name(s) for the block arg(s) if available. An operation
  // without a parent block has no block arguments to name.
  if (Block *parentBlock = op.getBlock())
    applyRetainedNames("mlir.blockArgNames",
                       [&]() { return parentBlock->getArguments(); });
}

/// Apply the retained region-argument names stored on the operation that owns
/// `region`. A region without a parent operation is left untouched.
void SSANameState::setRetainedIdentifierNames(Region &region) {
  Operation *parentOp = region.getParentOp();
  if (!parentOp)
    return;

  // The region-only path never records result groups, so a scratch vector is
  // enough to satisfy the interface.
  SmallVector<int, 2> scratchResultGroups;
  setRetainedIdentifierNames(*parentOp, scratchResultGroups,
                             /*hasRegion=*/true);
}

void SSANameState::getResultIDAndNumber(
OpResult result, Value &lookupValue,
std::optional<int> &lookupResultNo) const {
Expand Down Expand Up @@ -1629,20 +1716,22 @@ void SSANameState::getResultIDAndNumber(
lookupValue = owner->getResult(groupResultNo);
}

void SSANameState::setValueName(Value value, StringRef name) {
void SSANameState::setValueName(Value value, StringRef name,
bool allowNumeric) {
// If the name is empty, the value uses the default numbering.
if (name.empty()) {
valueIDs[value] = nextValueID++;
return;
}

valueIDs[value] = NameSentinel;
valueNames[value] = uniqueValueName(name);
valueNames[value] = uniqueValueName(name, allowNumeric);
}

StringRef SSANameState::uniqueValueName(StringRef name) {
StringRef SSANameState::uniqueValueName(StringRef name, bool allowNumeric) {
SmallString<16> tmpBuffer;
name = sanitizeIdentifier(name, tmpBuffer);
name = sanitizeIdentifier(name, tmpBuffer, /*allowedPunctChars=*/"$._-",
/*allowTrailingDigit=*/true, allowNumeric);

// Check to see if this name is already unique.
if (!usedNames.count(name)) {
Expand Down
Loading