Skip to content

[mlir] Retain original identifier names for debugging v2 #119944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mlir/include/mlir/IR/AsmState.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,10 @@ class ParserConfig {
/// `fallbackResourceMap` is an optional fallback handler that can be used to
/// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
FallbackAsmResourceMap *fallbackResourceMap = nullptr)
FallbackAsmResourceMap *fallbackResourceMap = nullptr,
bool retainIdentifierNames = false)
: context(context), verifyAfterParse(verifyAfterParse),
retainIdentifierNames(retainIdentifierNames),
fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
Expand All @@ -483,6 +485,10 @@ class ParserConfig {
/// Returns if the parser should verify the IR after parsing.
bool shouldVerifyAfterParse() const { return verifyAfterParse; }

/// Returns if the parser should retain identifier names collected using
/// parsing.
bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }

/// Returns the parsing configurations associated to the bytecode read.
BytecodeReaderConfig &getBytecodeReaderConfig() const {
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
Expand Down Expand Up @@ -520,6 +526,7 @@ class ParserConfig {
private:
MLIRContext *context;
bool verifyAfterParse;
bool retainIdentifierNames;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
FallbackAsmResourceMap *fallbackResourceMap;
BytecodeReaderConfig bytecodeReaderConfig;
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,10 @@ class OpPrintingFlags {
/// Return if printer should use unique SSA IDs.
bool shouldPrintUniqueSSAIDs() const;

/// Returns if the printer should retain identifier names collected using
/// parsing.
bool shouldPrintRetainedIdentifierNames() const;

private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
Expand Down Expand Up @@ -1254,6 +1258,9 @@ class OpPrintingFlags {

/// Print unique SSA IDs for values, block arguments and naming conflicts
bool printUniqueSSAIDsFlag : 1;

/// Print the retained original names of identifiers
bool printRetainedIdentifierNamesFlag : 1;
};

//===----------------------------------------------------------------------===//
Expand Down
18 changes: 13 additions & 5 deletions mlir/include/mlir/IR/Value.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ namespace detail {
/// This class provides the implementation for an operation result.
class alignas(8) OpResultImpl : public ValueImpl {
public:
using ValueImpl::ValueImpl;
OpResultImpl(Type type, Kind kind, Location loc)
: ValueImpl(type, kind), loc(loc) {}

static bool classof(const ValueImpl *value) {
return value->getKind() != ValueImpl::Kind::BlockArgument;
Expand All @@ -390,14 +391,17 @@ class alignas(8) OpResultImpl : public ValueImpl {
static unsigned getMaxInlineResults() {
return static_cast<unsigned>(Kind::OutOfLineOpResult);
}

/// The source location of this result.
Location loc;
};

/// This class provides the implementation for an operation result whose index
/// can be represented "inline" in the underlying ValueImpl.
struct InlineOpResult : public OpResultImpl {
public:
InlineOpResult(Type type, unsigned resultNo)
: OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo)) {
InlineOpResult(Type type, unsigned resultNo, Location loc)
: OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo), loc) {
assert(resultNo < getMaxInlineResults());
}

Expand All @@ -413,8 +417,8 @@ struct InlineOpResult : public OpResultImpl {
/// cannot be represented "inline", and thus requires an additional index field.
class OutOfLineOpResult : public OpResultImpl {
public:
OutOfLineOpResult(Type type, uint64_t outOfLineIndex)
: OpResultImpl(type, Kind::OutOfLineOpResult),
OutOfLineOpResult(Type type, uint64_t outOfLineIndex, Location loc)
: OpResultImpl(type, Kind::OutOfLineOpResult, loc),
outOfLineIndex(outOfLineIndex) {}

static bool classof(const OpResultImpl *value) {
Expand Down Expand Up @@ -468,6 +472,10 @@ class OpResult : public Value {
/// Returns the number of this result.
unsigned getResultNumber() const { return getImpl()->getResultNumber(); }

/// Return the location for this result.
Location getLoc() const { return getImpl()->loc; }
void setLoc(Location loc) { getImpl()->loc = loc; }

private:
/// Get a raw pointer to the internal implementation.
detail::OpResultImpl *getImpl() const {
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ class MlirOptMainConfig {
}
bool shouldVerifyPasses() const { return verifyPassesFlag; }

MlirOptMainConfig &retainIdentifierNames(bool retain) {
retainIdentifierNamesFlag = retain;
return *this;
}
bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }

/// Set whether to run the verifier on parsing.
MlirOptMainConfig &verifyOnParsing(bool verify) {
disableVerifierOnParsingFlag = !verify;
Expand Down Expand Up @@ -284,6 +290,9 @@ class MlirOptMainConfig {
/// Run the verifier after each transformation pass.
bool verifyPassesFlag = true;

/// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
bool retainIdentifierNamesFlag = false;

/// Disable the verifier on parsing.
bool disableVerifierOnParsingFlag = false;

Expand Down
103 changes: 81 additions & 22 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,10 @@ Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
//===----------------------------------------------------------------------===//

namespace {
/// This is the structure of a result specifier in the assembly syntax,
/// including the name, number of results, and location.
using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;

/// This class provides support for parsing operations and regions of
/// operations.
class OperationParser : public Parser {
Expand Down Expand Up @@ -618,7 +622,8 @@ class OperationParser : public Parser {
ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);

/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
Operation *parseGenericOperation(
std::optional<ArrayRef<ResultRecord>> resultIDs = std::nullopt);

/// Parse different components, viz., use-info of operand(s), successor(s),
/// region(s), attribute(s) and function-type, of the generic form of an
Expand Down Expand Up @@ -659,10 +664,6 @@ class OperationParser : public Parser {
/// token is actually an alias, which means it must not contain a dot.
ParseResult parseLocationAlias(LocationAttr &loc);

/// This is the structure of a result specifier in the assembly syntax,
/// including the name, number of results, and location.
using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;

/// Parse an operation instance that is in the op-defined custom form.
/// resultInfo specifies information about the "%name =" specifiers.
Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
Expand Down Expand Up @@ -1238,7 +1239,7 @@ ParseResult OperationParser::parseOperation() {
if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword())
op = parseCustomOperation(resultIDs);
else if (nameTok.is(Token::string))
op = parseGenericOperation();
op = parseGenericOperation(resultIDs);
else if (nameTok.isCodeCompletionFor(Token::string))
return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
else if (nameTok.isCodeCompletion())
Expand Down Expand Up @@ -1344,6 +1345,38 @@ struct CleanupOpStateRegions {
}
OperationState &state;
};

std::pair<StringRef, unsigned> getResultName(ArrayRef<ResultRecord> resultIDs,
unsigned resultNo) {
// Scan for the resultID that contains this result number.
for (const auto &entry : resultIDs) {
if (resultNo < std::get<1>(entry)) {
// Don't pass on the leading %.
StringRef name = std::get<0>(entry).drop_front();
return {name, resultNo};
}
resultNo -= std::get<1>(entry);
}

// Invalid result number.
return {"", ~0U};
}

std::pair<SMLoc, unsigned> getResultLoc(ArrayRef<ResultRecord> resultIDs,
unsigned resultNo) {
// Scan for the resultID that contains this result number.
for (const auto &entry : resultIDs) {
if (resultNo < std::get<1>(entry)) {
SMLoc loc = std::get<2>(entry);
return {loc, resultNo};
}
resultNo -= std::get<1>(entry);
}

// Invalid result number.
return {SMLoc{}, ~0U};
}

} // namespace

ParseResult OperationParser::parseGenericOperationAfterOpName(
Expand Down Expand Up @@ -1457,7 +1490,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
return success();
}

Operation *OperationParser::parseGenericOperation() {
Operation *OperationParser::parseGenericOperation(
std::optional<ArrayRef<ResultRecord>> maybeResultIDs) {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());

Expand Down Expand Up @@ -1531,6 +1565,17 @@ Operation *OperationParser::parseGenericOperation() {

// Create the operation and try to parse a location for it.
Operation *op = opBuilder.create(result);
if (state.config.shouldRetainIdentifierNames() && maybeResultIDs) {
for (OpResult opResult : op->getResults()) {
unsigned resultNum = opResult.getResultNumber();
Location resultLoc = getEncodedSourceLocation(
getResultLoc(*maybeResultIDs, resultNum).first);
opResult.setLoc(NameLoc::get(
StringAttr::get(state.config.getContext(),
getResultName(*maybeResultIDs, resultNum).first),
resultLoc));
}
}
if (parseTrailingLocationSpecifier(op))
return nullptr;

Expand Down Expand Up @@ -1571,7 +1616,7 @@ namespace {
class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
public:
CustomOpAsmParser(
SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
SMLoc nameLoc, ArrayRef<ResultRecord> resultIDs,
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
: AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
Expand Down Expand Up @@ -1634,18 +1679,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// getResultName(3) == {"z", 0 }
std::pair<StringRef, unsigned>
getResultName(unsigned resultNo) const override {
// Scan for the resultID that contains this result number.
for (const auto &entry : resultIDs) {
if (resultNo < std::get<1>(entry)) {
// Don't pass on the leading %.
StringRef name = std::get<0>(entry).drop_front();
return {name, resultNo};
}
resultNo -= std::get<1>(entry);
}

// Invalid result number.
return {"", ~0U};
return ::getResultName(resultIDs, resultNo);
}

/// Return the number of declared SSA results. This returns 4 for the foo.op
Expand Down Expand Up @@ -1962,7 +1996,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {

private:
/// Information about the result name specifiers.
ArrayRef<OperationParser::ResultRecord> resultIDs;
ArrayRef<ResultRecord> resultIDs;

/// The abstract information of the operation.
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
Expand Down Expand Up @@ -2093,6 +2127,18 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {

// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);

if (state.config.shouldRetainIdentifierNames()) {
for (OpResult opResult : op->getResults()) {
unsigned resultNum = opResult.getResultNumber();
Location resultLoc =
getEncodedSourceLocation(getResultLoc(resultIDs, resultNum).first);
StringRef resName = opAsmParser.getResultName(resultNum).first;
opResult.setLoc(NameLoc::get(
StringAttr::get(state.config.getContext(), resName), resultLoc));
}
}

if (parseTrailingLocationSpecifier(op))
return nullptr;

Expand Down Expand Up @@ -2159,8 +2205,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();

if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument)) {
op->setLoc(directLoc);
for (auto result : op->getResults())
result.setLoc(directLoc);
}
else
opOrArgument.get<BlockArgument>().setLoc(directLoc);
return success();
Expand Down Expand Up @@ -2235,6 +2284,11 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
Location loc = entryArg.sourceLoc.has_value()
? *entryArg.sourceLoc
: getEncodedSourceLocation(argInfo.location);
if (state.config.shouldRetainIdentifierNames()) {
loc = NameLoc::get(StringAttr::get(state.config.getContext(),
entryArg.ssaName.name.drop_front(1)),
loc);
}
BlockArgument arg = block->addArgument(entryArg.type, loc);

// Add a definition of this arg to the assembly state if provided.
Expand Down Expand Up @@ -2415,6 +2469,11 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
return emitError("argument and block argument type mismatch");
} else {
auto loc = getEncodedSourceLocation(useInfo.location);
if (state.config.shouldRetainIdentifierNames()) {
loc = NameLoc::get(StringAttr::get(state.config.getContext(),
useInfo.name.drop_front(1)),
loc);
}
arg = owner->addArgument(type, loc);
}

Expand Down
Loading
Loading