@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
28
28
void SetUp () override {
29
29
::test::registerTestDialect (registry);
30
30
context = std::make_unique<MLIRContext>(registry);
31
+ builder = std::make_unique<OpBuilder>(context.get ());
31
32
}
32
33
33
34
void testReplaceAllSymbolUses (ReplaceFnType replaceFn) {
34
35
// Set up IR and find func ops.
35
36
OwningOpRef<ModuleOp> module =
36
37
parseSourceString<ModuleOp>(kInput , context.get ());
38
+ ASSERT_TRUE (module );
37
39
SymbolTable symbolTable (module .get ());
38
40
auto opIterator = module ->getBody (0 )->getOperations ().begin ();
39
41
auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
46
48
ASSERT_TRUE (succeeded (res));
47
49
ASSERT_TRUE (succeeded (verify (module .get ())));
48
50
49
- // Check that it got renamed.
51
+ // Check that callee of the call op got renamed.
50
52
bool calleeFound = false ;
51
53
fooOp->walk ([&](CallOpInterface callOp) {
52
54
StringAttr callee = callOp.getCallableForCallee ()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
56
58
calleeFound = true ;
57
59
});
58
60
EXPECT_TRUE (calleeFound);
61
+
62
+ // Check that module attribute did *not* get renamed.
63
+ auto moduleAttr = (*module )->getAttrOfType <FlatSymbolRefAttr>(" test.attr" );
64
+ ASSERT_TRUE (moduleAttr);
65
+ EXPECT_EQ (moduleAttr.getValue (), StringRef (" bar" ));
59
66
}
60
67
61
68
std::unique_ptr<MLIRContext> context;
69
+ std::unique_ptr<OpBuilder> builder;
62
70
63
71
private:
64
72
constexpr static llvm::StringLiteral kInput = R"MLIR(
65
- module {
73
+ module attributes { test.attr = @bar } {
66
74
test.conversion_func_op private @foo() {
67
75
"test.conversion_call_op"() { callee=@bar } : () -> ()
68
76
"test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
81
89
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
82
90
auto barOp) -> LogicalResult {
83
91
return symbolTable.replaceAllSymbolUses (
84
- barOp, StringAttr::get (context. get (), " baz" ), module );
92
+ barOp, builder-> getStringAttr ( " baz" ), module );
85
93
});
86
94
}
87
95
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
90
98
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
91
99
auto barOp) -> LogicalResult {
92
100
return symbolTable.replaceAllSymbolUses (
93
- StringAttr::get (context.get (), " bar" ),
94
- StringAttr::get (context.get (), " baz" ), module );
101
+ builder->getStringAttr (" bar" ), builder->getStringAttr (" baz" ), module );
95
102
});
96
103
}
97
104
@@ -100,17 +107,17 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
100
107
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
101
108
auto barOp) -> LogicalResult {
102
109
return symbolTable.replaceAllSymbolUses (
103
- barOp, StringAttr::get (context. get (), " baz" ), &module ->getRegion (0 ));
110
+ barOp, builder-> getStringAttr ( " baz" ), &module ->getRegion (0 ));
104
111
});
105
112
}
106
113
107
114
TEST_F (ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108
115
// Symbol as `StringAttr`, rename within module body.
109
116
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
110
117
auto barOp) -> LogicalResult {
111
- return symbolTable.replaceAllSymbolUses (
112
- StringAttr::get (context. get (), " bar " ),
113
- StringAttr::get (context. get (), " baz " ), &module ->getRegion (0 ));
118
+ return symbolTable.replaceAllSymbolUses (builder-> getStringAttr ( " bar " ),
119
+ builder-> getStringAttr ( " baz " ),
120
+ &module ->getRegion (0 ));
114
121
});
115
122
}
116
123
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
119
126
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
120
127
auto barOp) -> LogicalResult {
121
128
return symbolTable.replaceAllSymbolUses (
122
- barOp, StringAttr::get (context. get (), " baz" ), fooOp);
129
+ barOp, builder-> getStringAttr ( " baz" ), fooOp);
123
130
});
124
131
}
125
132
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
128
135
testReplaceAllSymbolUses ([&](auto symbolTable, auto module , auto fooOp,
129
136
auto barOp) -> LogicalResult {
130
137
return symbolTable.replaceAllSymbolUses (
131
- StringAttr::get (context.get (), " bar" ),
132
- StringAttr::get (context.get (), " baz" ), fooOp);
138
+ builder->getStringAttr (" bar" ), builder->getStringAttr (" baz" ), fooOp);
133
139
});
134
140
}
135
141
0 commit comments