Skip to content

Commit d637d02

Browse files
committed
[mlir] retain identifier names
1 parent 60325ab commit d637d02

File tree

8 files changed

+174
-41
lines changed

8 files changed

+174
-41
lines changed

mlir/include/mlir/IR/AsmState.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,10 @@ class ParserConfig {
471471
/// `fallbackResourceMap` is an optional fallback handler that can be used to
472472
/// parse external resources not explicitly handled by another parser.
473473
ParserConfig(MLIRContext *context, bool verifyAfterParse = true,
474-
FallbackAsmResourceMap *fallbackResourceMap = nullptr)
474+
FallbackAsmResourceMap *fallbackResourceMap = nullptr,
475+
bool retainIdentifierNames = false)
475476
: context(context), verifyAfterParse(verifyAfterParse),
477+
retainIdentifierNames(retainIdentifierNames),
476478
fallbackResourceMap(fallbackResourceMap) {
477479
assert(context && "expected valid MLIR context");
478480
}
@@ -483,6 +485,10 @@ class ParserConfig {
483485
/// Returns if the parser should verify the IR after parsing.
484486
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
485487

488+
/// Returns if the parser should retain identifier names collected using
489+
/// parsing.
490+
bool shouldRetainIdentifierNames() const { return retainIdentifierNames; }
491+
486492
/// Returns the parsing configurations associated to the bytecode read.
487493
BytecodeReaderConfig &getBytecodeReaderConfig() const {
488494
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
@@ -520,6 +526,7 @@ class ParserConfig {
520526
private:
521527
MLIRContext *context;
522528
bool verifyAfterParse;
529+
bool retainIdentifierNames;
523530
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
524531
FallbackAsmResourceMap *fallbackResourceMap;
525532
BytecodeReaderConfig bytecodeReaderConfig;

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
12211221
/// Return if printer should use unique SSA IDs.
12221222
bool shouldPrintUniqueSSAIDs() const;
12231223

1224+
/// Returns if the printer should retain identifier names collected using
1225+
/// parsing.
1226+
bool shouldPrintRetainedIdentifierNames() const;
1227+
12241228
private:
12251229
/// Elide large elements attributes if the number of elements is larger than
12261230
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
12541258

12551259
/// Print unique SSA IDs for values, block arguments and naming conflicts
12561260
bool printUniqueSSAIDsFlag : 1;
1261+
1262+
/// Print the retained original names of identifiers
1263+
bool printRetainedIdentifierNamesFlag : 1;
12571264
};
12581265

12591266
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/Value.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ namespace detail {
367367
/// This class provides the implementation for an operation result.
368368
class alignas(8) OpResultImpl : public ValueImpl {
369369
public:
370-
using ValueImpl::ValueImpl;
370+
OpResultImpl(Type type, Kind kind, Location loc)
371+
: ValueImpl(type, kind), loc(loc) {}
371372

372373
static bool classof(const ValueImpl *value) {
373374
return value->getKind() != ValueImpl::Kind::BlockArgument;
@@ -390,14 +391,17 @@ class alignas(8) OpResultImpl : public ValueImpl {
390391
static unsigned getMaxInlineResults() {
391392
return static_cast<unsigned>(Kind::OutOfLineOpResult);
392393
}
394+
395+
/// The source location of this result.
396+
Location loc;
393397
};
394398

395399
/// This class provides the implementation for an operation result whose index
396400
/// can be represented "inline" in the underlying ValueImpl.
397401
struct InlineOpResult : public OpResultImpl {
398402
public:
399-
InlineOpResult(Type type, unsigned resultNo)
400-
: OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo)) {
403+
InlineOpResult(Type type, unsigned resultNo, Location loc)
404+
: OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo), loc) {
401405
assert(resultNo < getMaxInlineResults());
402406
}
403407

@@ -413,8 +417,8 @@ struct InlineOpResult : public OpResultImpl {
413417
/// cannot be represented "inline", and thus requires an additional index field.
414418
class OutOfLineOpResult : public OpResultImpl {
415419
public:
416-
OutOfLineOpResult(Type type, uint64_t outOfLineIndex)
417-
: OpResultImpl(type, Kind::OutOfLineOpResult),
420+
OutOfLineOpResult(Type type, uint64_t outOfLineIndex, Location loc)
421+
: OpResultImpl(type, Kind::OutOfLineOpResult, loc),
418422
outOfLineIndex(outOfLineIndex) {}
419423

420424
static bool classof(const OpResultImpl *value) {
@@ -468,6 +472,10 @@ class OpResult : public Value {
468472
/// Returns the number of this result.
469473
unsigned getResultNumber() const { return getImpl()->getResultNumber(); }
470474

475+
/// Return the location for this result.
476+
Location getLoc() const { return getImpl()->loc; }
477+
void setLoc(Location loc) { getImpl()->loc = loc; }
478+
471479
private:
472480
/// Get a raw pointer to the internal implementation.
473481
detail::OpResultImpl *getImpl() const {

mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ class MlirOptMainConfig {
198198
}
199199
bool shouldVerifyPasses() const { return verifyPassesFlag; }
200200

201+
MlirOptMainConfig &retainIdentifierNames(bool retain) {
202+
retainIdentifierNamesFlag = retain;
203+
return *this;
204+
}
205+
bool shouldRetainIdentifierNames() const { return retainIdentifierNamesFlag; }
206+
201207
/// Set whether to run the verifier on parsing.
202208
MlirOptMainConfig &verifyOnParsing(bool verify) {
203209
disableVerifierOnParsingFlag = !verify;
@@ -284,6 +290,9 @@ class MlirOptMainConfig {
284290
/// Run the verifier after each transformation pass.
285291
bool verifyPassesFlag = true;
286292

293+
/// Retain identifier names in the output (e.g., `%my_var` instead of `%0`).
294+
bool retainIdentifierNamesFlag = false;
295+
287296
/// Disable the verifier on parsing.
288297
bool disableVerifierOnParsingFlag = false;
289298

mlir/lib/AsmParser/Parser.cpp

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
543543
//===----------------------------------------------------------------------===//
544544

545545
namespace {
546+
/// This is the structure of a result specifier in the assembly syntax,
547+
/// including the name, number of results, and location.
548+
using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
549+
546550
/// This class provides support for parsing operations and regions of
547551
/// operations.
548552
class OperationParser : public Parser {
@@ -618,7 +622,8 @@ class OperationParser : public Parser {
618622
ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);
619623

620624
/// Parse an operation instance that is in the generic form.
621-
Operation *parseGenericOperation();
625+
Operation *parseGenericOperation(
626+
std::optional<ArrayRef<ResultRecord>> resultIDs = std::nullopt);
622627

623628
/// Parse different components, viz., use-info of operand(s), successor(s),
624629
/// region(s), attribute(s) and function-type, of the generic form of an
@@ -659,10 +664,6 @@ class OperationParser : public Parser {
659664
/// token is actually an alias, which means it must not contain a dot.
660665
ParseResult parseLocationAlias(LocationAttr &loc);
661666

662-
/// This is the structure of a result specifier in the assembly syntax,
663-
/// including the name, number of results, and location.
664-
using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
665-
666667
/// Parse an operation instance that is in the op-defined custom form.
667668
/// resultInfo specifies information about the "%name =" specifiers.
668669
Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
@@ -1238,7 +1239,7 @@ ParseResult OperationParser::parseOperation() {
12381239
if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword())
12391240
op = parseCustomOperation(resultIDs);
12401241
else if (nameTok.is(Token::string))
1241-
op = parseGenericOperation();
1242+
op = parseGenericOperation(resultIDs);
12421243
else if (nameTok.isCodeCompletionFor(Token::string))
12431244
return codeCompleteStringDialectOrOperationName(nameTok.getStringValue());
12441245
else if (nameTok.isCodeCompletion())
@@ -1344,6 +1345,38 @@ struct CleanupOpStateRegions {
13441345
}
13451346
OperationState &state;
13461347
};
1348+
1349+
std::pair<StringRef, unsigned> getResultName(ArrayRef<ResultRecord> resultIDs,
1350+
unsigned resultNo) {
1351+
// Scan for the resultID that contains this result number.
1352+
for (const auto &entry : resultIDs) {
1353+
if (resultNo < std::get<1>(entry)) {
1354+
// Don't pass on the leading %.
1355+
StringRef name = std::get<0>(entry).drop_front();
1356+
return {name, resultNo};
1357+
}
1358+
resultNo -= std::get<1>(entry);
1359+
}
1360+
1361+
// Invalid result number.
1362+
return {"", ~0U};
1363+
}
1364+
1365+
std::pair<SMLoc, unsigned> getResultLoc(ArrayRef<ResultRecord> resultIDs,
1366+
unsigned resultNo) {
1367+
// Scan for the resultID that contains this result number.
1368+
for (const auto &entry : resultIDs) {
1369+
if (resultNo < std::get<1>(entry)) {
1370+
SMLoc loc = std::get<2>(entry);
1371+
return {loc, resultNo};
1372+
}
1373+
resultNo -= std::get<1>(entry);
1374+
}
1375+
1376+
// Invalid result number.
1377+
return {SMLoc{}, ~0U};
1378+
}
1379+
13471380
} // namespace
13481381

13491382
ParseResult OperationParser::parseGenericOperationAfterOpName(
@@ -1457,7 +1490,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
14571490
return success();
14581491
}
14591492

1460-
Operation *OperationParser::parseGenericOperation() {
1493+
Operation *OperationParser::parseGenericOperation(
1494+
std::optional<ArrayRef<ResultRecord>> maybeResultIDs) {
14611495
// Get location information for the operation.
14621496
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
14631497

@@ -1531,6 +1565,17 @@ Operation *OperationParser::parseGenericOperation() {
15311565

15321566
// Create the operation and try to parse a location for it.
15331567
Operation *op = opBuilder.create(result);
1568+
if (state.config.shouldRetainIdentifierNames() && maybeResultIDs) {
1569+
for (OpResult opResult : op->getResults()) {
1570+
unsigned resultNum = opResult.getResultNumber();
1571+
Location resultLoc = getEncodedSourceLocation(
1572+
getResultLoc(*maybeResultIDs, resultNum).first);
1573+
opResult.setLoc(NameLoc::get(
1574+
StringAttr::get(state.config.getContext(),
1575+
getResultName(*maybeResultIDs, resultNum).first),
1576+
resultLoc));
1577+
}
1578+
}
15341579
if (parseTrailingLocationSpecifier(op))
15351580
return nullptr;
15361581

@@ -1571,7 +1616,7 @@ namespace {
15711616
class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
15721617
public:
15731618
CustomOpAsmParser(
1574-
SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
1619+
SMLoc nameLoc, ArrayRef<ResultRecord> resultIDs,
15751620
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
15761621
bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
15771622
: AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
@@ -1634,18 +1679,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
16341679
/// getResultName(3) == {"z", 0 }
16351680
std::pair<StringRef, unsigned>
16361681
getResultName(unsigned resultNo) const override {
1637-
// Scan for the resultID that contains this result number.
1638-
for (const auto &entry : resultIDs) {
1639-
if (resultNo < std::get<1>(entry)) {
1640-
// Don't pass on the leading %.
1641-
StringRef name = std::get<0>(entry).drop_front();
1642-
return {name, resultNo};
1643-
}
1644-
resultNo -= std::get<1>(entry);
1645-
}
1646-
1647-
// Invalid result number.
1648-
return {"", ~0U};
1682+
return ::getResultName(resultIDs, resultNo);
16491683
}
16501684

16511685
/// Return the number of declared SSA results. This returns 4 for the foo.op
@@ -1962,7 +1996,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
19621996

19631997
private:
19641998
/// Information about the result name specifiers.
1965-
ArrayRef<OperationParser::ResultRecord> resultIDs;
1999+
ArrayRef<ResultRecord> resultIDs;
19662000

19672001
/// The abstract information of the operation.
19682002
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
@@ -2093,6 +2127,18 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
20932127

20942128
// Otherwise, create the operation and try to parse a location for it.
20952129
Operation *op = opBuilder.create(opState);
2130+
2131+
if (state.config.shouldRetainIdentifierNames()) {
2132+
for (OpResult opResult : op->getResults()) {
2133+
unsigned resultNum = opResult.getResultNumber();
2134+
Location resultLoc =
2135+
getEncodedSourceLocation(getResultLoc(resultIDs, resultNum).first);
2136+
StringRef resName = opAsmParser.getResultName(resultNum).first;
2137+
opResult.setLoc(NameLoc::get(
2138+
StringAttr::get(state.config.getContext(), resName), resultLoc));
2139+
}
2140+
}
2141+
20962142
if (parseTrailingLocationSpecifier(op))
20972143
return nullptr;
20982144

@@ -2235,6 +2281,11 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
22352281
Location loc = entryArg.sourceLoc.has_value()
22362282
? *entryArg.sourceLoc
22372283
: getEncodedSourceLocation(argInfo.location);
2284+
if (state.config.shouldRetainIdentifierNames()) {
2285+
loc = NameLoc::get(StringAttr::get(state.config.getContext(),
2286+
entryArg.ssaName.name.drop_front(1)),
2287+
loc);
2288+
}
22382289
BlockArgument arg = block->addArgument(entryArg.type, loc);
22392290

22402291
// Add a definition of this arg to the assembly state if provided.
@@ -2415,6 +2466,11 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) {
24152466
return emitError("argument and block argument type mismatch");
24162467
} else {
24172468
auto loc = getEncodedSourceLocation(useInfo.location);
2469+
if (state.config.shouldRetainIdentifierNames()) {
2470+
loc = NameLoc::get(StringAttr::get(state.config.getContext(),
2471+
useInfo.name.drop_front(1)),
2472+
loc);
2473+
}
24182474
arg = owner->addArgument(type, loc);
24192475
}
24202476

0 commit comments

Comments
 (0)