Skip to content

Commit d062cb8

Browse files
committed
Added support for result groups
1 parent 467c3a3 commit d062cb8

File tree

3 files changed

+72
-21
lines changed

3 files changed

+72
-21
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,15 +1275,18 @@ OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
12751275
/// Store the SSA names for the current operation as attrs for debug purposes.
12761276
void OperationParser::storeSSANames(Operation *&op,
12771277
ArrayRef<ResultRecord> resultIDs) {
1278-
if (op->getNumResults() > 1)
1279-
emitError("have not yet implemented support for multiple return values\n");
1280-
1281-
for (const ResultRecord &resIt : resultIDs) {
1282-
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
1283-
op->setDiscardableAttr(
1284-
"mlir.ssaName",
1285-
StringAttr::get(getContext(), std::get<0>(resIt).drop_front(1)));
1278+
1279+
// Store the name(s) of the result(s) of this operation.
1280+
if (op->getNumResults() > 0) {
1281+
llvm::SmallVector<llvm::StringRef, 1> resultNames;
1282+
for (const ResultRecord &resIt : resultIDs) {
1283+
resultNames.push_back(std::get<0>(resIt).drop_front(1));
1284+
// Insert empty string for sub-results/result groups
1285+
for (unsigned int i = 1; i < std::get<1>(resIt); ++i)
1286+
resultNames.push_back(llvm::StringRef());
12861287
}
1288+
op->setDiscardableAttr("mlir.resultNames",
1289+
builder.getStrArrayAttr(resultNames));
12871290
}
12881291

12891292
// Store the name information of the arguments of this operation.

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,8 @@ class SSANameState {
13001300

13011301
/// Set the original identifier names if available. Used in debugging with
13021302
/// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
1303-
void setRetainedIdentifierNames(Operation &op);
1303+
void setRetainedIdentifierNames(Operation &op,
1304+
SmallVector<int, 2> &resultGroups);
13041305

13051306
/// This is the value ID for each SSA value. If this returns NameSentinel,
13061307
/// then the valueID has an entry in valueNames.
@@ -1573,7 +1574,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
15731574

15741575
// Set the original identifier names if available. Used in debugging with
15751576
// `--retain-identifier-names`/`shouldRetainIdentifierNames` in ParserConfig
1576-
setRetainedIdentifierNames(op);
1577+
setRetainedIdentifierNames(op, resultGroups);
15771578

15781579
unsigned numResults = op.getNumResults();
15791580
if (numResults == 0) {
@@ -1597,25 +1598,33 @@ void SSANameState::numberValuesInOp(Operation &op) {
15971598
}
15981599
}
15991600

1600-
void SSANameState::setRetainedIdentifierNames(Operation &op) {
1601-
// Get the original SSA for the result(s) if available
1602-
unsigned numResults = op.getNumResults();
1603-
if (numResults > 1)
1604-
llvm::outs()
1605-
<< "have not yet implemented support for multiple return values\n";
1606-
else if (numResults == 1) {
1607-
Value resultBegin = op.getResult(0);
1608-
if (StringAttr ssaNameAttr = op.getAttrOfType<StringAttr>("mlir.ssaName")) {
1609-
setValueName(resultBegin, ssaNameAttr.strref());
1610-
op.removeDiscardableAttr("mlir.ssaName");
1601+
void SSANameState::setRetainedIdentifierNames(
1602+
Operation &op, SmallVector<int, 2> &resultGroups) {
1603+
// Get the original names for the results if available
1604+
if (ArrayAttr resultNamesAttr =
1605+
op.getAttrOfType<ArrayAttr>("mlir.resultNames")) {
1606+
auto resultNames = resultNamesAttr.getValue();
1607+
auto results = op.getResults();
1608+
// Conservative in the case that the #results has changed
1609+
for (size_t i = 0; i < results.size() && i < resultNames.size(); ++i) {
1610+
auto resultName = resultNames[i].cast<StringAttr>().strref();
1611+
if (!resultName.empty()) {
1612+
if (!usedNames.count(resultName))
1613+
setValueName(results[i], resultName);
1614+
// If a result has a name, it is the start of a result group.
1615+
if (i > 0)
1616+
resultGroups.push_back(i);
1617+
}
16111618
}
1619+
op.removeDiscardableAttr("mlir.resultNames");
16121620
}
16131621

16141622
// Get the original name for the op args if available
16151623
if (ArrayAttr opArgNamesAttr =
16161624
op.getAttrOfType<ArrayAttr>("mlir.opArgNames")) {
16171625
auto opArgNames = opArgNamesAttr.getValue();
16181626
auto opArgs = op.getOperands();
1627+
// Conservative in the case that the #operands has changed
16191628
for (size_t i = 0; i < opArgs.size() && i < opArgNames.size(); ++i) {
16201629
auto opArgName = opArgNames[i].cast<StringAttr>().strref();
16211630
if (!usedNames.count(opArgName))
@@ -1636,6 +1645,7 @@ void SSANameState::setRetainedIdentifierNames(Operation &op) {
16361645
op.getAttrOfType<ArrayAttr>("mlir.blockArgNames")) {
16371646
auto blockArgNames = blockArgNamesAttr.getValue();
16381647
auto blockArgs = op.getBlock()->getArguments();
1648+
// Conservative in the case that the #args has changed
16391649
for (size_t i = 0; i < blockArgs.size() && i < blockArgNames.size(); ++i) {
16401650
auto blockArgName = blockArgNames[i].cast<StringAttr>().strref();
16411651
if (!usedNames.count(blockArgName))

mlir/test/IR/print-retain-identifiers.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,41 @@ func.func @simple(i64, i1) -> i64 {
5252
}
5353

5454
// -----
55+
56+
//===----------------------------------------------------------------------===//
57+
// Test multiple return values
58+
//===----------------------------------------------------------------------===//
59+
60+
func.func @select_min_max(%a: f64, %b: f64) -> (f64, f64) {
61+
%gt = arith.cmpf "ogt", %a, %b : f64
62+
// CHECK: %min, %max = scf.if %gt -> (f64, f64) {
63+
%min, %max = scf.if %gt -> (f64, f64) {
64+
scf.yield %b, %a : f64, f64
65+
} else {
66+
scf.yield %a, %b : f64, f64
67+
}
68+
// CHECK: return %min, %max : f64, f64
69+
return %min, %max : f64, f64
70+
}
71+
72+
// -----
73+
74+
////===----------------------------------------------------------------------===//
75+
// Test multiple return values, with a grouped value tuple
76+
//===----------------------------------------------------------------------===//
77+
78+
func.func @select_max(%a: f64, %b: f64, %c: f64, %d: f64) -> (f64, f64, f64, f64) {
79+
// Find the max between %a and %b,
80+
// with %c and %d being other values that are returned.
81+
%gt = arith.cmpf "ogt", %a, %b : f64
82+
// CHECK: %max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
83+
%max, %others:2, %alt = scf.if %gt -> (f64, f64, f64, f64) {
84+
scf.yield %b, %a, %c, %d : f64, f64, f64, f64
85+
} else {
86+
scf.yield %a, %b, %d, %c : f64, f64, f64, f64
87+
}
88+
// CHECK: return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
89+
return %max, %others#0, %others#1, %alt : f64, f64, f64, f64
90+
}
91+
92+
// -----

0 commit comments

Comments
 (0)