Skip to content

Commit 596da62

Browse files
Chris Lattnerlattner
authored andcommitted
Add support for custom op parser/printer hooks to know about result names.
Summary: This allows the custom parser/printer hooks to do interesting things with the SSA names. This patch: - Adds a new 'getResultName' method to OpAsmParser that allows a parser implementation to get information about its result names, along with a getNumResults() method that allows op parser impls to know how many results are expected. - Adds a OpAsmPrinter::printOperand overload that takes an explicit stream. - Adds a test.string_attr_pretty_name operation that uses these hooks to do fancy things with the result name. Reviewers: rriddle! Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76205
1 parent 5ff5ddd commit 596da62

File tree

6 files changed

+221
-16
lines changed

6 files changed

+221
-16
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class OpAsmPrinter {
3737

3838
/// Print implementations for various things an operation contains.
3939
virtual void printOperand(Value value) = 0;
40+
virtual void printOperand(Value value, raw_ostream &os) = 0;
4041

4142
/// Print a comma separated list of operands.
4243
template <typename ContainerType>
@@ -245,6 +246,24 @@ class OpAsmParser {
245246
return success();
246247
}
247248

249+
/// Return the name of the specified result in the specified syntax, as well
250+
/// as the sub-element in the name. It returns an empty string and ~0U for
251+
/// invalid result numbers. For example, in this operation:
252+
///
253+
/// %x, %y:2, %z = foo.op
254+
///
255+
/// getResultName(0) == {"x", 0 }
256+
/// getResultName(1) == {"y", 0 }
257+
/// getResultName(2) == {"y", 1 }
258+
/// getResultName(3) == {"z", 0 }
259+
/// getResultName(4) == {"", ~0U }
260+
virtual std::pair<StringRef, unsigned>
261+
getResultName(unsigned resultNo) const = 0;
262+
263+
/// Return the number of declared SSA results. This returns 4 for the foo.op
264+
/// example in the comment for `getResultName`.
265+
virtual size_t getNumResults() const = 0;
266+
248267
/// Return the location of the original name token.
249268
virtual llvm::SMLoc getNameLoc() const = 0;
250269

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,10 @@ void SSANameState::setValueName(Value value, StringRef name) {
765765
static bool isPunct(char c) {
766766
return c == '$' || c == '.' || c == '_' || c == '-';
767767
}
768-
768+
769769
StringRef SSANameState::uniqueValueName(StringRef name) {
770770
assert(!name.empty() && "Shouldn't have an empty name here");
771-
771+
772772
// Check to see if this name is valid. If it starts with a digit, then it
773773
// could conflict with the autogenerated numeric ID's (we unique them in a
774774
// different map), so add an underscore prefix to avoid problems.
@@ -777,13 +777,13 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
777777
tmpName += name;
778778
return uniqueValueName(tmpName);
779779
}
780-
780+
781781
// Check to see if the name consists of all-valid identifiers. If not, we
782782
// need to escape them.
783783
for (char ch : name) {
784784
if (isalpha(ch) || isPunct(ch) || isdigit(ch))
785785
continue;
786-
786+
787787
SmallString<16> tmpName;
788788
for (char ch : name) {
789789
if (isalpha(ch) || isPunct(ch) || isdigit(ch))
@@ -796,7 +796,7 @@ StringRef SSANameState::uniqueValueName(StringRef name) {
796796
}
797797
return uniqueValueName(tmpName);
798798
}
799-
799+
800800
// Check to see if this name is already unique.
801801
if (!usedNames.count(name)) {
802802
name = name.copy(usedNameAllocator);
@@ -1963,7 +1963,8 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
19631963
bool printBlockTerminator = true);
19641964

19651965
/// Print the ID of the given value, optionally with its result number.
1966-
void printValueID(Value value, bool printResultNo = true) const;
1966+
void printValueID(Value value, bool printResultNo = true,
1967+
raw_ostream *streamOverride = nullptr) const;
19671968

19681969
//===--------------------------------------------------------------------===//
19691970
// OpAsmPrinter methods
@@ -1988,6 +1989,9 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
19881989

19891990
/// Print the ID for the given value.
19901991
void printOperand(Value value) override { printValueID(value); }
1992+
void printOperand(Value value, raw_ostream &os) override {
1993+
printValueID(value, /*printResultNo=*/true, &os);
1994+
}
19911995

19921996
/// Print an optional attribute dictionary with a given set of elided values.
19931997
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@@ -2195,8 +2199,10 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
21952199
currentIndent -= indentWidth;
21962200
}
21972201

2198-
void OperationPrinter::printValueID(Value value, bool printResultNo) const {
2199-
state->getSSANameState().printValueID(value, printResultNo, os);
2202+
void OperationPrinter::printValueID(Value value, bool printResultNo,
2203+
raw_ostream *streamOverride) const {
2204+
state->getSSANameState().printValueID(value, printResultNo,
2205+
streamOverride ? *streamOverride : os);
22002206
}
22012207

22022208
void OperationPrinter::printSuccessor(Block *successor) {

mlir/lib/Parser/Parser.cpp

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3322,8 +3322,13 @@ class OperationParser : public Parser {
33223322
Operation *parseGenericOperation(Block *insertBlock,
33233323
Block::iterator insertPt);
33243324

3325+
/// This is the structure of a result specifier in the assembly syntax,
3326+
/// including the name, number of results, and location.
3327+
typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord;
3328+
33253329
/// Parse an operation instance that is in the op-defined custom form.
3326-
Operation *parseCustomOperation();
3330+
/// resultInfo specifies information about the "%name =" specifiers.
3331+
Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo);
33273332

33283333
//===--------------------------------------------------------------------===//
33293334
// Region Parsing
@@ -3728,7 +3733,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
37283733
///
37293734
ParseResult OperationParser::parseOperation() {
37303735
auto loc = getToken().getLoc();
3731-
SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
3736+
SmallVector<ResultRecord, 1> resultIDs;
37323737
size_t numExpectedResults = 0;
37333738
if (getToken().is(Token::percent_identifier)) {
37343739
// Parse the group of result ids.
@@ -3769,7 +3774,7 @@ ParseResult OperationParser::parseOperation() {
37693774

37703775
Operation *op;
37713776
if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
3772-
op = parseCustomOperation();
3777+
op = parseCustomOperation(resultIDs);
37733778
else if (getToken().is(Token::string))
37743779
op = parseGenericOperation();
37753780
else
@@ -3790,7 +3795,7 @@ ParseResult OperationParser::parseOperation() {
37903795

37913796
// Add definitions for each of the result groups.
37923797
unsigned opResI = 0;
3793-
for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
3798+
for (ResultRecord &resIt : resultIDs) {
37943799
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
37953800
if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
37963801
op->getResult(opResI++)))
@@ -3955,9 +3960,12 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
39553960
namespace {
39563961
class CustomOpAsmParser : public OpAsmParser {
39573962
public:
3958-
CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
3963+
CustomOpAsmParser(SMLoc nameLoc,
3964+
ArrayRef<OperationParser::ResultRecord> resultIDs,
3965+
const AbstractOperation *opDefinition,
39593966
OperationParser &parser)
3960-
: nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
3967+
: nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
3968+
parser(parser) {}
39613969

39623970
/// Parse an instance of the operation described by 'opDefinition' into the
39633971
/// provided operation state.
@@ -3992,6 +4000,41 @@ class CustomOpAsmParser : public OpAsmParser {
39924000

39934001
Builder &getBuilder() const override { return parser.builder; }
39944002

4003+
/// Return the name of the specified result in the specified syntax, as well
4004+
/// as the subelement in the name. For example, in this operation:
4005+
///
4006+
/// %x, %y:2, %z = foo.op
4007+
///
4008+
/// getResultName(0) == {"x", 0 }
4009+
/// getResultName(1) == {"y", 0 }
4010+
/// getResultName(2) == {"y", 1 }
4011+
/// getResultName(3) == {"z", 0 }
4012+
std::pair<StringRef, unsigned>
4013+
getResultName(unsigned resultNo) const override {
4014+
// Scan for the resultID that contains this result number.
4015+
for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) {
4016+
const auto &entry = resultIDs[nameID];
4017+
if (resultNo < std::get<1>(entry)) {
4018+
// Don't pass on the leading %.
4019+
StringRef name = std::get<0>(entry).drop_front();
4020+
return {name, resultNo};
4021+
}
4022+
resultNo -= std::get<1>(entry);
4023+
}
4024+
4025+
// Invalid result number.
4026+
return {"", ~0U};
4027+
}
4028+
4029+
/// Return the number of declared SSA results. This returns 4 for the foo.op
4030+
/// example in the comment for getResultName.
4031+
size_t getNumResults() const override {
4032+
size_t count = 0;
4033+
for (auto &entry : resultIDs)
4034+
count += std::get<1>(entry);
4035+
return count;
4036+
}
4037+
39954038
llvm::SMLoc getNameLoc() const override { return nameLoc; }
39964039

39974040
//===--------------------------------------------------------------------===//
@@ -4500,6 +4543,9 @@ class CustomOpAsmParser : public OpAsmParser {
45004543
/// The source location of the operation name.
45014544
SMLoc nameLoc;
45024545

4546+
/// Information about the result name specifiers.
4547+
ArrayRef<OperationParser::ResultRecord> resultIDs;
4548+
45034549
/// The abstract information of the operation.
45044550
const AbstractOperation *opDefinition;
45054551

@@ -4511,7 +4557,8 @@ class CustomOpAsmParser : public OpAsmParser {
45114557
};
45124558
} // end anonymous namespace.
45134559

4514-
Operation *OperationParser::parseCustomOperation() {
4560+
Operation *
4561+
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
45154562
auto opLoc = getToken().getLoc();
45164563
auto opName = getTokenSpelling();
45174564

@@ -4544,7 +4591,7 @@ Operation *OperationParser::parseCustomOperation() {
45444591
// Have the op implementation take a crack and parsing this.
45454592
OperationState opState(srcLocation, opDefinition->name);
45464593
CleanupOpStateRegions guard{opState};
4547-
CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
4594+
CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
45484595
if (opAsmParser.parseOperation(opState))
45494596
return nullptr;
45504597

mlir/test/IR/parser.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,3 +1185,43 @@ func @custom_asm_names() -> (i32, i32, i32, i32, i32, i32, i32) {
11851185
// CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]]
11861186
return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32
11871187
}
1188+
1189+
1190+
// CHECK-LABEL: func @pretty_names
1191+
1192+
// This tests the behavior
1193+
func @pretty_names() {
1194+
// Simple case, should parse and print as %x being an implied 'name'
1195+
// attribute.
1196+
%x = test.string_attr_pretty_name
1197+
// CHECK: %x = test.string_attr_pretty_name
1198+
// CHECK-NOT: attributes
1199+
1200+
// This specifies an explicit name, which should override the result.
1201+
%YY = test.string_attr_pretty_name attributes { names = ["y"] }
1202+
// CHECK: %y = test.string_attr_pretty_name
1203+
// CHECK-NOT: attributes
1204+
1205+
// Conflicts with the 'y' name, so need an explicit attribute.
1206+
%0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32
1207+
// CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]}
1208+
1209+
// Name contains a space.
1210+
%1 = "test.string_attr_pretty_name"() { names = ["space name"]} : () -> i32
1211+
// CHECK: %space_name = test.string_attr_pretty_name attributes {names = ["space name"]}
1212+
1213+
"unknown.use"(%x, %YY, %0, %1) : (i32, i32, i32, i32) -> ()
1214+
1215+
// Multi-result support.
1216+
1217+
%a, %b, %c = test.string_attr_pretty_name
1218+
// CHECK: %a, %b, %c = test.string_attr_pretty_name
1219+
// CHECK-NOT: attributes
1220+
1221+
%q:3, %r = test.string_attr_pretty_name
1222+
// CHECK: %q, %q_1, %q_2, %r = test.string_attr_pretty_name attributes {names = ["q", "q", "q", "r"]}
1223+
1224+
// CHECK: return
1225+
return
1226+
}
1227+

mlir/test/lib/TestDialect/TestDialect.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,87 @@ void SideEffectOp::getEffects(
391391
}
392392
}
393393

394+
//===----------------------------------------------------------------------===//
395+
// StringAttrPrettyNameOp
396+
//===----------------------------------------------------------------------===//
397+
398+
// This op has fancy handling of its SSA result name.
399+
static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
400+
OperationState &result) {
401+
// Add the result types.
402+
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
403+
result.addTypes(parser.getBuilder().getIntegerType(32));
404+
405+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
406+
return failure();
407+
408+
// If the attribute dictionary contains no 'names' attribute, infer it from
409+
// the SSA name (if specified).
410+
bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
411+
return attr.first.is("names");
412+
});
413+
414+
// If there was no name specified, check to see if there was a useful name
415+
// specified in the asm file.
416+
if (hadNames || parser.getNumResults() == 0)
417+
return success();
418+
419+
SmallVector<StringRef, 4> names;
420+
auto *context = result.getContext();
421+
422+
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
423+
auto resultName = parser.getResultName(i);
424+
StringRef nameStr;
425+
if (!resultName.first.empty() && !isdigit(resultName.first[0]))
426+
nameStr = resultName.first;
427+
428+
names.push_back(nameStr);
429+
}
430+
431+
auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
432+
result.attributes.push_back({Identifier::get("names", context), namesAttr});
433+
return success();
434+
}
435+
436+
static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
437+
p << "test.string_attr_pretty_name";
438+
439+
// Note that we only need to print the "name" attribute if the asmprinter
440+
// result name disagrees with it. This can happen in strange cases, e.g.
441+
// when there are conflicts.
442+
bool namesDisagree = op.names().size() != op.getNumResults();
443+
444+
SmallString<32> resultNameStr;
445+
for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
446+
resultNameStr.clear();
447+
llvm::raw_svector_ostream tmpStream(resultNameStr);
448+
p.printOperand(op.getResult(i), tmpStream);
449+
450+
auto expectedName = op.names()[i].dyn_cast<StringAttr>();
451+
if (!expectedName ||
452+
tmpStream.str().drop_front() != expectedName.getValue()) {
453+
namesDisagree = true;
454+
}
455+
}
456+
457+
if (namesDisagree)
458+
p.printOptionalAttrDictWithKeyword(op.getAttrs());
459+
else
460+
p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
461+
}
462+
463+
// We set the SSA name in the asm syntax to the contents of the name
464+
// attribute.
465+
void StringAttrPrettyNameOp::getAsmResultNames(
466+
function_ref<void(Value, StringRef)> setNameFn) {
467+
468+
auto value = names();
469+
for (size_t i = 0, e = value.size(); i != e; ++i)
470+
if (auto str = value[i].dyn_cast<StringAttr>())
471+
if (!str.getValue().empty())
472+
setNameFn(getResult(i), str.getValue());
473+
}
474+
394475
//===----------------------------------------------------------------------===//
395476
// Dialect Registration
396477
//===----------------------------------------------------------------------===//

mlir/test/lib/TestDialect/TestOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,18 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
496496
);
497497
}
498498

499+
// This is used to test encoding of a string attribute into an SSA name of a
500+
// pretty printed value name.
501+
def StringAttrPrettyNameOp
502+
: TEST_Op<"string_attr_pretty_name",
503+
[DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
504+
let arguments = (ins StrArrayAttr:$names);
505+
let results = (outs Variadic<I32>:$r);
506+
507+
let printer = [{ return ::print(p, *this); }];
508+
let parser = [{ return ::parse$cppClass(parser, result); }];
509+
}
510+
499511
//===----------------------------------------------------------------------===//
500512
// Test Patterns
501513
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)