Skip to content

Commit da400cc

Browse files
committed
Fix formatting of cluster labels
1 parent c681a20 commit da400cc

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

mlir/lib/Transforms/ViewOpGraph.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
4949
return buf;
5050
}
5151

52-
/// Escape special characters such as '\n' and quotation marks.
53-
static std::string escapeString(std::string str) {
54-
return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
55-
}
56-
5752
/// Put quotation marks around a given string.
5853
static std::string quoteString(const std::string &str) {
5954
return "\"" + str + "\"";
@@ -169,8 +164,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
169164
os.indent();
170165
// Emit invisible anchor node from/to which arrows can be drawn.
171166
Node anchorNode = emitNodeStmt(" ", kShapeNone);
172-
os << attrStmt("label", quoteString(escapeString(std::move(label))))
173-
<< ";\n";
167+
os << attrStmt("label", quoteString(label)) << ";\n";
174168
builder();
175169
os.unindent();
176170
os << "}\n";
@@ -288,8 +282,32 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
288282
return str;
289283
}
290284

285+
std::string getClusterLabel(Operation *op) {
286+
return strFromOs([&](raw_ostream &os) {
287+
// Print operation name and type.
288+
os << op->getName();
289+
if (printResultTypes) {
290+
os << " : (";
291+
std::string buf;
292+
llvm::raw_string_ostream ss(buf);
293+
interleaveComma(op->getResultTypes(), ss);
294+
os << truncateString(buf) << ")";
295+
}
296+
297+
// Print attributes.
298+
if (printAttrs) {
299+
os << "\\l";
300+
for (const NamedAttribute &attr : op->getAttrs()) {
301+
os << attr.getName().getValue() << ": ";
302+
emitMlirAttr(os, attr.getValue());
303+
os << "\\l";
304+
}
305+
}
306+
});
307+
}
308+
291309
/// Generate a label for an operation.
292-
std::string getLabel(Operation *op) {
310+
std::string getRecordLabel(Operation *op) {
293311
return strFromOs([&](raw_ostream &os) {
294312
os << "{";
295313

@@ -369,9 +387,9 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
369387
for (Region &region : op->getRegions())
370388
processRegion(region);
371389
},
372-
getLabel(op));
390+
getClusterLabel(op));
373391
} else {
374-
node = emitNodeStmt(getLabel(op), kShapeNode,
392+
node = emitNodeStmt(getRecordLabel(op), kShapeNode,
375393
backgroundColors[op->getName()].second);
376394
}
377395

0 commit comments

Comments
 (0)