Skip to content

Commit 93be82c

Browse files
committed
add test
1 parent 26c0de4 commit 93be82c

File tree

3 files changed

+133
-33
lines changed

3 files changed

+133
-33
lines changed

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,8 +1221,8 @@ class OpPrintingFlags {
12211221
/// Return if printer should use unique SSA IDs.
12221222
bool shouldPrintUniqueSSAIDs() const;
12231223

1224-
/// Returns if the printer should retain identifier names collected using
1225-
/// parsing.
1224+
/// Return if the printer should use NameLocs as prefixes when printing SSA
1225+
/// IDs
12261226
bool shouldUseNameLocAsPrefix() const;
12271227

12281228
private:
@@ -1259,7 +1259,7 @@ class OpPrintingFlags {
12591259
/// Print unique SSA IDs for values, block arguments and naming conflicts
12601260
bool printUniqueSSAIDsFlag : 1;
12611261

1262-
/// Print the retained original names of identifiers
1262+
/// Print SSA IDs using NameLocs as prefixes
12631263
bool useNameLocAsPrefix : 1;
12641264
};
12651265

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ struct AsmPrinterOptions {
197197
llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
198198
"and naming conflicts across all regions")};
199199

200-
llvm::cl::opt<bool> useNameLocAsPrefix{"mlir-use-nameloc-as-prefix",
201-
llvm::cl::init(false),
202-
llvm::cl::desc("TODO")};
200+
llvm::cl::opt<bool> useNameLocAsPrefix{
201+
"mlir-use-nameloc-as-prefix", llvm::cl::init(false),
202+
llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
203203
};
204204
} // namespace
205205

@@ -369,7 +369,7 @@ bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const {
369369
return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
370370
}
371371

372-
/// TODO
372+
/// Return if the printer should use NameLocs as prefixes when printing SSA IDs
373373
bool OpPrintingFlags::shouldUseNameLocAsPrefix() const {
374374
return useNameLocAsPrefix;
375375
}
@@ -1523,25 +1523,18 @@ void SSANameState::numberValuesInRegion(Region &region) {
15231523
assert(!valueIDs.count(arg) && "arg numbered multiple times");
15241524
assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
15251525
"arg not defined in current region");
1526-
setValueName(arg, name);
1526+
if (printerFlags.shouldUseNameLocAsPrefix() && isa<NameLoc>(arg.getLoc())) {
1527+
auto nameLoc = cast<NameLoc>(arg.getLoc());
1528+
setValueName(arg, nameLoc.getName());
1529+
} else {
1530+
setValueName(arg, name);
1531+
}
15271532
};
15281533

1529-
bool alreadySetNames = false;
15301534
if (!printerFlags.shouldPrintGenericOpForm()) {
15311535
if (Operation *op = region.getParentOp()) {
1532-
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) {
1536+
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
15331537
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1534-
alreadySetNames = true;
1535-
}
1536-
}
1537-
}
1538-
1539-
if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames) {
1540-
for (BlockArgument arg : region.getArguments()) {
1541-
if (isa<NameLoc>(arg.getLoc())) {
1542-
auto nameLoc = cast<NameLoc>(arg.getLoc());
1543-
setBlockArgNameFn(arg, nameLoc.getName());
1544-
}
15451538
}
15461539
}
15471540

@@ -1596,7 +1589,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
15961589
auto setResultNameFn = [&](Value result, StringRef name) {
15971590
assert(!valueIDs.count(result) && "result numbered multiple times");
15981591
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1599-
setValueName(result, name);
1592+
if (printerFlags.shouldUseNameLocAsPrefix() &&
1593+
isa<NameLoc>(result.getLoc())) {
1594+
auto nameLoc = cast<NameLoc>(result.getLoc());
1595+
setValueName(result, nameLoc.getName());
1596+
} else {
1597+
setValueName(result, name);
1598+
}
16001599

16011600
// Record the result number for groups not anchored at 0.
16021601
if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
@@ -1618,25 +1617,14 @@ void SSANameState::numberValuesInOp(Operation &op) {
16181617
blockNames[block] = {-1, name};
16191618
};
16201619

1621-
bool alreadySetNames = false;
16221620
if (!printerFlags.shouldPrintGenericOpForm()) {
16231621
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
16241622
asmInterface.getAsmBlockNames(setBlockNameFn);
16251623
asmInterface.getAsmResultNames(setResultNameFn);
1626-
alreadySetNames = true;
16271624
}
16281625
}
16291626

16301627
unsigned numResults = op.getNumResults();
1631-
if (printerFlags.shouldUseNameLocAsPrefix() && !alreadySetNames &&
1632-
numResults > 0) {
1633-
Value resultBegin = op.getResult(0);
1634-
if (isa<NameLoc>(resultBegin.getLoc())) {
1635-
auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
1636-
setResultNameFn(resultBegin, nameLoc.getName());
1637-
}
1638-
}
1639-
16401628
if (numResults == 0) {
16411629
// If value users should be printed, operations with no result need an id.
16421630
if (printerFlags.shouldPrintValueUsers()) {
@@ -1647,6 +1635,13 @@ void SSANameState::numberValuesInOp(Operation &op) {
16471635
}
16481636
Value resultBegin = op.getResult(0);
16491637

1638+
if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
1639+
if (isa<NameLoc>(resultBegin.getLoc())) {
1640+
auto nameLoc = cast<NameLoc>(resultBegin.getLoc());
1641+
setResultNameFn(resultBegin, nameLoc.getName());
1642+
}
1643+
}
1644+
16501645
// If the first result wasn't numbered, give it a default number.
16511646
if (valueIDs.try_emplace(resultBegin, nextValueID).second)
16521647
++nextValueID;
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// RUN: mlir-opt %s -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -test-loop-unrolling='unroll-factor=2' -mlir-use-nameloc-as-prefix -split-input-file | FileCheck %s --check-prefix=CHECK-PASS-PRESERVE
3+
4+
// CHECK-LABEL: test_basic
5+
func.func @test_basic() {
6+
%0 = memref.alloc() : memref<i32>
7+
// CHECK: %alice
8+
%1 = memref.load %0[] : memref<i32> loc("alice")
9+
return
10+
}
11+
12+
// -----
13+
14+
// CHECK-LABEL: test_repeat_namelocs
15+
func.func @test_repeat_namelocs() {
16+
%0 = memref.alloc() : memref<i32>
17+
// CHECK: %alice
18+
%1 = memref.load %0[] : memref<i32> loc("alice")
19+
// CHECK: %alice_0
20+
%2 = memref.load %0[] : memref<i32> loc("alice")
21+
return
22+
}
23+
24+
// -----
25+
26+
// CHECK-LABEL: test_bb_args
27+
func.func @test_bb_args1(%arg0 : memref<i32> loc("foo")) {
28+
// CHECK: %alice = memref.load %foo
29+
%1 = memref.load %arg0[] : memref<i32> loc("alice")
30+
return
31+
}
32+
33+
// -----
34+
35+
func.func private @make_two_results() -> (index, index)
36+
37+
// CHECK-LABEL: test_multiple_results
38+
func.func @test_multiple_results(%cond: i1) {
39+
// CHECK: %foo:2
40+
%0:2 = call @make_two_results() : () -> (index, index) loc("foo")
41+
// CHECK: %bar:2
42+
%1, %2 = call @make_two_results() : () -> (index, index) loc("bar")
43+
44+
// CHECK: %kevin:2 = scf.while (%arg1 = %bar#0, %arg2 = %bar#0)
45+
%5:2 = scf.while (%arg1 = %1, %arg2 = %1) : (index, index) -> (index, index) {
46+
%6 = arith.cmpi slt, %arg1, %arg2 : index
47+
scf.condition(%6) %arg1, %arg2 : index, index
48+
} do {
49+
// CHECK: ^bb0(%alice: index, %bob: index)
50+
^bb0(%arg3 : index loc("alice"), %arg4: index loc("bob")):
51+
%c1, %c2 = func.call @make_two_results() : () -> (index, index) loc("harriet")
52+
// CHECK: scf.yield %harriet#1, %harriet#1
53+
scf.yield %c2, %c2 : index, index
54+
} loc("kevin")
55+
return
56+
}
57+
58+
// -----
59+
60+
#map = affine_map<(d0) -> (d0)>
61+
#trait = {
62+
iterator_types = ["parallel"],
63+
indexing_maps = [#map, #map, #map]
64+
}
65+
66+
// CHECK-LABEL: test_op_asm_interface
67+
func.func @test_op_asm_interface(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
68+
// CHECK: %c0
69+
%0 = arith.constant 0 : index
70+
// CHECK: %foo
71+
%1 = arith.constant 1 : index loc("foo")
72+
73+
linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
74+
// CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32)
75+
^bb0(%a: f32, %b: f32, %c: f32):
76+
linalg.yield %a, %a : f32, f32
77+
} -> (tensor<?xf32>, tensor<?xf32>)
78+
79+
linalg.generic #trait ins(%arg0: tensor<?xf32>) outs(%arg0, %arg1: tensor<?xf32>, tensor<?xf32>) {
80+
// CHECK: ^bb0(%bar: f32, %alice: f32, %steve: f32)
81+
^bb0(%a: f32 loc("bar"), %b: f32 loc("alice"), %c: f32 loc("steve")):
82+
// CHECK: linalg.yield %alice, %steve
83+
linalg.yield %b, %c : f32, f32
84+
} -> (tensor<?xf32>, tensor<?xf32>)
85+
86+
return
87+
}
88+
89+
// -----
90+
91+
// CHECK-LABEL: test_pass
92+
func.func @test_pass(%arg0: memref<4xf32>, %arg1: memref<4xf32>) {
93+
%c0 = arith.constant 0 : index
94+
%c1 = arith.constant 1 : index
95+
%c4 = arith.constant 4 : index
96+
scf.for %arg2 = %c0 to %c4 step %c1 {
97+
// CHECK-PASS-PRESERVE: %foo = memref.load
98+
// CHECK-PASS-PRESERVE: memref.store %foo
99+
// CHECK-PASS-PRESERVE: %foo_1 = memref.load
100+
// CHECK-PASS-PRESERVE: memref.store %foo_1
101+
%0 = memref.load %arg0[%arg2] : memref<4xf32> loc("foo")
102+
memref.store %0, %arg1[%arg2] : memref<4xf32>
103+
}
104+
return
105+
}

0 commit comments

Comments
 (0)