Skip to content

Commit f539e00

Browse files
authored
[mlir] add option to print SSA IDs using NameLocs as prefixes (#119996)
This PR adds an `AsmPrinter` option `-mlir-use-nameloc-as-prefix` which uses trailing `NameLoc`s, if the source IR provides them, as prefixes when printing SSA IDs.
1 parent 6a7d6c5 commit f539e00

File tree

4 files changed

+170
-3
lines changed

4 files changed

+170
-3
lines changed

mlir/docs/PassManagement.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,27 @@ $ tree /tmp/pipeline_output
13981398
│ │ ├── 1_1_pass4.mlir
13991399
```
14001400

1401+
* `mlir-use-nameloc-as-prefix`
1402+
* If your source IR has named locations (`loc("named_location")"`) then passing this flag will use those
1403+
names (`named_location`) to prefix the corresponding SSA identifiers:
1404+
1405+
```mlir
1406+
%1 = memref.load %0[] : memref<i32> loc("alice")
1407+
%2 = memref.load %0[] : memref<i32> loc("bob")
1408+
%3 = memref.load %0[] : memref<i32> loc("bob")
1409+
```
1410+
1411+
will print
1412+
1413+
```mlir
1414+
%alice = memref.load %0[] : memref<i32>
1415+
%bob = memref.load %0[] : memref<i32>
1416+
%bob_0 = memref.load %0[] : memref<i32>
1417+
```
1418+
1419+
These names will also be preserved through passes to newly created operations if using the appropriate location.
1420+
1421+
14011422
## Crash and Failure Reproduction
14021423
14031424
The [pass manager](#pass-manager) in MLIR contains a builtin mechanism to

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,10 @@ class OpPrintingFlags {
12211221
/// Return if printer should use unique SSA IDs.
12221222
bool shouldPrintUniqueSSAIDs() const;
12231223

1224+
/// Return if the printer should use NameLocs as prefixes when printing SSA
1225+
/// IDs
1226+
bool shouldUseNameLocAsPrefix() const;
1227+
12241228
private:
12251229
/// Elide large elements attributes if the number of elements is larger than
12261230
/// the upper limit.
@@ -1254,6 +1258,9 @@ class OpPrintingFlags {
12541258

12551259
/// Print unique SSA IDs for values, block arguments and naming conflicts
12561260
bool printUniqueSSAIDsFlag : 1;
1261+
1262+
/// Print SSA IDs using NameLocs as prefixes
1263+
bool useNameLocAsPrefix : 1;
12571264
};
12581265

12591266
//===----------------------------------------------------------------------===//

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ OpAsmParser::~OpAsmParser() = default;
7373
MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
7474

7575
/// Parse a type list.
76-
/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
76+
/// This is out-of-line to work-around
77+
/// https://github.com/llvm/llvm-project/issues/62918
7778
ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
7879
return parseCommaSeparatedList(
7980
[&]() { return parseType(result.emplace_back()); });
@@ -195,6 +196,10 @@ struct AsmPrinterOptions {
195196
"mlir-print-unique-ssa-ids", llvm::cl::init(false),
196197
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
197198
"and naming conflicts across all regions")};
199+
200+
llvm::cl::opt<bool> useNameLocAsPrefix{
201+
"mlir-use-nameloc-as-prefix", llvm::cl::init(false),
202+
llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
198203
};
199204
} // namespace
200205

@@ -212,7 +217,8 @@ OpPrintingFlags::OpPrintingFlags()
212217
: printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
213218
printGenericOpFormFlag(false), skipRegionsFlag(false),
214219
assumeVerifiedFlag(false), printLocalScope(false),
215-
printValueUsersFlag(false), printUniqueSSAIDsFlag(false) {
220+
printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
221+
useNameLocAsPrefix(false) {
216222
// Initialize based upon command line options, if they are available.
217223
if (!clOptions.isConstructed())
218224
return;
@@ -231,6 +237,7 @@ OpPrintingFlags::OpPrintingFlags()
231237
skipRegionsFlag = clOptions->skipRegionsOpt;
232238
printValueUsersFlag = clOptions->printValueUsers;
233239
printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
240+
useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
234241
}
235242

236243
/// Enable the elision of large elements attributes, by printing a '...'
@@ -362,6 +369,11 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
362369
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
363370
}
364371

372+
/// Return if the printer should use NameLocs as prefixes when printing SSA IDs.
373+
bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
374+
return useNameLocAsPrefix;
375+
}
376+
365377
//===----------------------------------------------------------------------===//
366378
// NewLineCounter
367379
//===----------------------------------------------------------------------===//
@@ -1506,11 +1518,22 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
15061518
}
15071519
}
15081520

1521+
namespace {
1522+
/// Try to get value name from value's location, fallback to `name`.
1523+
StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
1524+
if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>())
1525+
return maybeNameLoc.getName();
1526+
return name;
1527+
}
1528+
} // namespace
1529+
15091530
void SSANameState::numberValuesInRegion(Region &region) {
15101531
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
15111532
assert(!valueIDs.count(arg) && "arg numbered multiple times");
15121533
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
15131534
"arg not defined in current region");
1535+
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1536+
name = maybeGetValueNameFromLoc(arg, name);
15141537
setValueName(arg, name);
15151538
};
15161539

@@ -1553,7 +1576,10 @@ void SSANameState::numberValuesInBlock(Block &block) {
15531576
specialNameBuffer.resize(strlen("arg"));
15541577
specialName << nextArgumentID++;
15551578
}
1556-
setValueName(arg, specialName.str());
1579+
StringRef specialNameStr = specialName.str();
1580+
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1581+
specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr);
1582+
setValueName(arg, specialNameStr);
15571583
}
15581584

15591585
// Number the operations in this block.
@@ -1567,6 +1593,8 @@ void SSANameState::numberValuesInOp(Operation &op) {
15671593
auto setResultNameFn = [&](Value result, StringRef name) {
15681594
assert(!valueIDs.count(result) && "result numbered multiple times");
15691595
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1596+
if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1597+
name = maybeGetValueNameFromLoc(result, name);
15701598
setValueName(result, name);
15711599

15721600
// Record the result number for groups not anchored at 0.
@@ -1607,6 +1635,12 @@ void SSANameState::numberValuesInOp(Operation &op) {
16071635
}
16081636
Value resultBegin = op.getResult(0);
16091637

1638+
if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
1639+
if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) {
1640+
setValueName(resultBegin, nameLoc.getName());
1641+
}
1642+
}
1643+
16101644
// If the first result wasn't numbered, give it a default number.
16111645
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
16121646
++nextValueID;
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
3+
4+
// CHECK-LABEL: test_basic
5+
func.func @test_basic() {
6+
%0 = memref.alloc() : memref<i32>
7+
// CHECK: %alice = memref.load
8+
%1 = memref.load %0[] : memref<i32> loc("alice")
9+
return
10+
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: test_repeat_namelocs
15+
func.func @test_repeat_namelocs() {
16+
%0 = memref.alloc() : memref<i32>
17+
// CHECK: %alice = memref.load
18+
%1 = memref.load %0[] : memref<i32> loc("alice")
19+
// CHECK: %alice_0 = memref.load
20+
%2 = memref.load %0[] : memref<i32> loc("alice")
21+
return
22+
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: test_bb_args
27+
func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
28+
// CHECK: %alice = memref.load %foo
29+
%1 = memref.load %arg0[] : memref<i32> loc("alice")
30+
return
31+
}
32+
33+
// -----
34+
35+
func.func private @make_two_results() -> (index, index)
36+
37+
// CHECK-LABEL: test_multiple_results
38+
func.func @test_multiple_results(%cond: i1) {
39+
// CHECK: %foo:2 = call @make_two_results
40+
%0:2 = call @make_two_results() : () -> (index, index) loc("foo")
41+
// CHECK: %bar:2 = call @make_two_results
42+
%1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
43+
44+
// CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
45+
%5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
46+
%6 = arith.cmpi slt, %arg1, %arg2 : index
47+
scf.condition(%6) %arg1, %arg2 : index, index
48+
} do {
49+
// CHECK: ^bb0(%alice: index, %bob: index)
50+
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
51+
%c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
52+
// CHECK: scf.yield %harriet#1, %harriet#1
53+
scf.yield %c2, %c2 : index, index
54+
} loc("kevin")
55+
return
56+
}
57+
58+
// -----
59+
60+
#map = affine_map<(d0) -> (d0)>
61+
#trait = {
62+
iterator_types = ["parallel"],
63+
indexing_maps = [#map, #map, #map]
64+
}
65+
66+
// CHECK-LABEL: test_op_asm_interface
67+
func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
68+
// CHECK: %c0 = arith.constant
69+
%0 = arith.constant 0 : index
70+
// CHECK: %foo = arith.constant
71+
%1 = arith.constant 1 : index loc("foo")
72+
73+
linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
74+
// CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
75+
^bb0(%a: f32, %b: f32, %c: f32):
76+
linalg.yield %a, %a : f32, f32
77+
} -> (tensor<?xf32>, tensor<?xf32>)
78+
79+
linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
80+
// CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
81+
^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
82+
// CHECK: linalg.yield %alice, %steve
83+
linalg.yield %b, %c : f32, f32
84+
} -> (tensor<?xf32>, tensor<?xf32>)
85+
86+
return
87+
}
88+
89+
// -----
90+
91+
// CHECK-LABEL: test_pass
92+
func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
93+
%c0 = arith.constant 0 : index
94+
%c1 = arith.constant 1 : index
95+
%c4 = arith.constant 4 : index
96+
scf.for %arg2 = %c0 to %c4 step %c1 {
97+
// CHECK-PASS-PRESERVE: %foo = memref.load
98+
// CHECK-PASS-PRESERVE: memref.store %foo
99+
// CHECK-PASS-PRESERVE: %foo_1 = memref.load
100+
// CHECK-PASS-PRESERVE: memref.store %foo_1
101+
%0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
102+
memref.store %0, %arg1[%arg2] : memref<4xf32>
103+
}
104+
return
105+
}

0 commit comments

Comments
 (0)