Skip to content

Commit b257ab5

Browse files
committed
Added region arg support & ambiguous name test
1 parent 233a987 commit b257ab5

File tree

3 files changed

+115
-62
lines changed

3 files changed

+115
-62
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,9 @@ class OperationParser : public Parser {
611611
/// an object of type 'OperationName'. Otherwise, failure is returned.
612612
FailureOr<OperationName> parseCustomOperationName();
613613

614-
/// Store the SSA names for the current operation as attrs for debug purposes.
615-
void storeSSANames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
614+
/// Store the identifier names for the current operation as attrs for debug
615+
/// purposes.
616+
void storeIdentifierNames(Operation *&op, ArrayRef<ResultRecord> resultIDs);
616617
DenseMap<Value, StringRef> argNames;
617618

618619
//===--------------------------------------------------------------------===//
@@ -1273,8 +1274,8 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
12731274
}
12741275

12751276
/// Store the SSA names for the current operation as attrs for debug purposes.
1276-
void OperationParser::storeSSANames(Operation *&op,
1277-
ArrayRef<ResultRecord> resultIDs) {
1277+
void OperationParser::storeIdentifierNames(Operation *&op,
1278+
ArrayRef<ResultRecord> resultIDs) {
12781279

12791280
// Store the name(s) of the result(s) of this operation.
12801281
if (op->getNumResults() > 0) {
@@ -1322,6 +1323,18 @@ void OperationParser::storeSSANames(Operation *&op,
13221323
}
13231324
}
13241325
}
1326+
1327+
// Store names of region arguments (e.g., for FuncOps)
1328+
if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) {
1329+
llvm::SmallVector<llvm::StringRef, 1> regionArgNames;
1330+
for (BlockArgument arg : op->getRegion(0).getArguments()) {
1331+
auto it = argNames.find(arg);
1332+
if (it != argNames.end()) {
1333+
regionArgNames.push_back(it->second.drop_front(1));
1334+
}
1335+
}
1336+
op->setAttr("mlir.regionArgNames", builder.getStrArrayAttr(regionArgNames));
1337+
}
13251338
}
13261339

13271340
namespace {
@@ -2093,9 +2106,9 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
20932106
// Otherwise, create the operation and try to parse a location for it.
20942107
Operation *op = opBuilder.create(opState);
20952108

2096-
// If enabled, store the SSA name(s) for the operation
2109+
// If enabled, store the original identifier name(s) for the operation
20972110
if (state.config.shouldRetainIdentifierNames())
2098-
storeSSANames(op, resultIDs);
2111+
storeIdentifierNames(op, resultIDs);
20992112

21002113
if (parseTrailingLocationSpecifier(op))
21012114
return nullptr;
@@ -2246,6 +2259,9 @@ ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
22462259
if (state.asmState)
22472260
state.asmState->addDefinition(arg, argInfo.location);
22482261

2262+
if (state.config.shouldRetainIdentifierNames())
2263+
argNames.insert({arg, argInfo.name});
2264+
22492265
// Record the definition for this argument.
22502266
if (addDefinition(argInfo, arg))
22512267
return failure();

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,9 @@ class SSANameState {
13031303
/// Set the original identifier names if available. Used in debugging with
13041304
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
13051305
void setRetainedIdentifierNames(Operation &op,
1306-
SmallVector<int, 2> &resultGroups);
1306+
SmallVector<int, 2> &resultGroups,
1307+
bool hasRegion = false);
1308+
void setRetainedIdentifierNames(Region &region);
13071309

13081310
/// This is the value ID for each SSA value. If this returns NameSentinel,
13091311
/// then the valueID has an entry in valueNames.
@@ -1492,6 +1494,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
14921494
setValueName(arg, name);
14931495
};
14941496

1497+
// Use manually specified region arg names if available
1498+
setRetainedIdentifierNames(region);
1499+
14951500
if (!printerFlags.shouldPrintGenericOpForm()) {
14961501
if (Operation *op = region.getParentOp()) {
14971502
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
@@ -1603,64 +1608,75 @@ void SSANameState::numberValuesInOp(Operation &op) {
16031608
}
16041609
}
16051610

1606-
void SSANameState::setRetainedIdentifierNames(
1607-
Operation &op, SmallVector<int, 2> &resultGroups) {
1608-
// Get the original names for the results if available
1609-
if (ArrayAttr resultNamesAttr =
1610-
op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
1611-
auto resultNames = resultNamesAttr.getValue();
1612-
auto results = op.getResults();
1613-
// Conservative in the case that the #results has changed
1614-
for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
1615-
auto resultName = resultNames[i].cast<StringAttr>().strref();
1616-
if (!resultName.empty()) {
1617-
if (!usedNames.count(resultName))
1618-
setValueName(results[i], resultName, /*allowNumeric=*/true);
1619-
// If a result has a name, it is the start of a result group.
1620-
if (i > 0)
1621-
resultGroups.push_back(i);
1622-
}
1623-
}
1624-
op.removeDiscardableAttr("mlir.resultNames");
1625-
}
1626-
1627-
// Get the original name for the op args if available
1628-
if (ArrayAttr opArgNamesAttr =
1629-
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
1630-
auto opArgNames = opArgNamesAttr.getValue();
1631-
auto opArgs = op.getOperands();
1632-
// Conservative in the case that the #operands has changed
1633-
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
1634-
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
1635-
if (!usedNames.count(opArgName))
1636-
setValueName(opArgs[i], opArgName, /*allowNumeric=*/true);
1611+
void SSANameState::setRetainedIdentifierNames(Operation &op,
1612+
SmallVector<int, 2> &resultGroups,
1613+
bool hasRegion) {
1614+
1615+
// Lambda which fetches the list of relevant attributes (e.g.,
1616+
// mlir.resultNames) and associates them with the relevant values
1617+
auto handleNamedAttributes =
1618+
[this](Operation &op, const Twine &attrName, auto getValuesFunc,
1619+
std::optional<std::function<void(int)>> customAction =
1620+
std::nullopt) {
1621+
if (ArrayAttr namesAttr = op.getAttrOfType<ArrayAttr>(attrName.str())) {
1622+
auto names = namesAttr.getValue();
1623+
auto values = getValuesFunc();
1624+
// Conservative in case the number of values has changed
1625+
for (size_t i = 0; i < values.size() && i < names.size(); ++i) {
1626+
auto name = names[i].cast<StringAttr>().strref();
1627+
if (!name.empty()) {
1628+
if (!this->usedNames.count(name))
1629+
this->setValueName(values[i], name, true);
1630+
if (customAction.has_value())
1631+
customAction.value()(i);
1632+
}
1633+
}
1634+
op.removeDiscardableAttr(attrName.str());
1635+
}
1636+
};
1637+
1638+
if (hasRegion) {
1639+
// Get the original name(s) for the region arg(s) if available (e.g., for
1640+
// FuncOp args). Requires hasRegion flag to ensure scoping is correct
1641+
if (hasRegion && op.getNumRegions() > 0 &&
1642+
op.getRegion(0).getNumArguments() > 0) {
1643+
handleNamedAttributes(op, "mlir.regionArgNames",
1644+
[&]() { return op.getRegion(0).getArguments(); });
16371645
}
1638-
op.removeDiscardableAttr("mlir.opArgNames");
1639-
}
1640-
1641-
// Get the original name for the block if available
1642-
if (StringAttr blockNameAttr =
1643-
op.getAttrOfType<StringAttr>("mlir.blockName")) {
1644-
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
1645-
op.removeDiscardableAttr("mlir.blockName");
1646-
}
1647-
1648-
// Get the original name for the block args if available
1649-
if (ArrayAttr blockArgNamesAttr =
1650-
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
1651-
auto blockArgNames = blockArgNamesAttr.getValue();
1652-
auto blockArgs = op.getBlock()->getArguments();
1653-
// Conservative in the case that the #args has changed
1654-
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
1655-
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
1656-
if (!usedNames.count(blockArgName))
1657-
setValueName(blockArgs[i], blockArgName, /*allowNumeric=*/true);
1646+
} else {
1647+
// Get the original names for the results if available
1648+
handleNamedAttributes(
1649+
op, "mlir.resultNames", [&]() { return op.getResults(); },
1650+
[&resultGroups](int i) { /*handles result groups*/
1651+
if (i > 0)
1652+
resultGroups.push_back(i);
1653+
});
1654+
1655+
// Get the original name for the op args if available
1656+
handleNamedAttributes(op, "mlir.opArgNames",
1657+
[&]() { return op.getOperands(); });
1658+
1659+
// Get the original name for the block if available
1660+
if (StringAttr blockNameAttr =
1661+
op.getAttrOfType<StringAttr>("mlir.blockName")) {
1662+
blockNames[op.getBlock()] = {-1, blockNameAttr.strref()};
1663+
op.removeDiscardableAttr("mlir.blockName");
16581664
}
1659-
op.removeDiscardableAttr("mlir.blockArgNames");
1665+
1666+
// Get the original name(s) for the block arg(s) if available
1667+
handleNamedAttributes(op, "mlir.blockArgNames",
1668+
[&]() { return op.getBlock()->getArguments(); });
16601669
}
16611670
return;
16621671
}
16631672

1673+
void SSANameState::setRetainedIdentifierNames(Region &region) {
1674+
if (Operation *op = region.getParentOp()) {
1675+
SmallVector<int, 2> resultGroups;
1676+
setRetainedIdentifierNames(*op, resultGroups, true);
1677+
}
1678+
}
1679+
16641680
void SSANameState::getResultIDAndNumber(
16651681
OpResult result, Value &lookupValue,
16661682
std::optional<int> &lookupResultNo) const {

mlir/test/IR/print-retain-identifiers.mlir

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
// Test SSA results (with single return values)
66
//===----------------------------------------------------------------------===//
77

8-
// CHECK: func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
9-
func.func @add_one(%my_input: f64, %arg1: f64) -> f64 {
8+
// CHECK: func.func @add_one(%my_input: f64) -> f64 {
9+
func.func @add_one(%my_input: f64) -> f64 {
1010
// CHECK: %my_constant = arith.constant 1.000000e+00 : f64
1111
%my_constant = arith.constant 1.000000e+00 : f64
1212
// CHECK: %my_output = arith.addf %my_input, %my_constant : f64
@@ -71,7 +71,7 @@ func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
7171

7272
// -----
7373

74-
////===----------------------------------------------------------------------===//
74+
//===----------------------------------------------------------------------===//
7575
// Test multiple return values, with a grouped value tuple
7676
//===----------------------------------------------------------------------===//
7777

@@ -90,3 +90,24 @@ func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64
9090
}
9191

9292
// -----
93+
94+
//===----------------------------------------------------------------------===//
95+
// Test identifiers which may clash with OpAsmOpInterface names (e.g., cst, %1, etc)
96+
//===----------------------------------------------------------------------===//
97+
98+
// CHECK: func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
99+
func.func @clash(%arg1: f64, %arg0: f64, %arg2: f64) -> f64 {
100+
%my_constant = arith.constant 1.000000e+00 : f64
101+
// CHECK: %cst = arith.constant 2.000000e+00 : f64
102+
%cst = arith.constant 2.000000e+00 : f64
103+
// CHECK: %cst_1 = arith.constant 3.000000e+00 : f64
104+
%cst_1 = arith.constant 3.000000e+00 : f64
105+
// CHECK: %1 = arith.addf %arg1, %cst : f64
106+
%1 = arith.addf %arg1, %cst : f64
107+
// CHECK: %0 = arith.addf %arg1, %cst_1 : f64
108+
%0 = arith.addf %arg1, %cst_1 : f64
109+
// CHECK: return %1 : f64
110+
return %1 : f64
111+
}
112+
113+
// -----

0 commit comments

Comments
 (0)