Skip to content

Commit da99b86

Browse files
[mlir] Extend tests of SymbolTable::replaceAllSymbolUses.
This is a follow-up commit for 4790578@llvm/llvm-project (llvm#68320) that adds more tests. In particular, the tests now check that the `limit` op itself is not traversed, i.e., symbols in attributes in of the `limit` op are not renamed.
1 parent 4803ba9 commit da99b86

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

mlir/unittests/IR/SymbolTableTest.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
2828
void SetUp() override {
2929
::test::registerTestDialect(registry);
3030
context = std::make_unique<MLIRContext>(registry);
31+
builder = std::make_unique<OpBuilder>(context.get());
3132
}
3233

3334
void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
3435
// Set up IR and find func ops.
3536
OwningOpRef<ModuleOp> module =
3637
parseSourceString<ModuleOp>(kInput, context.get());
38+
ASSERT_TRUE(module);
3739
SymbolTable symbolTable(module.get());
3840
auto opIterator = module->getBody(0)->getOperations().begin();
3941
auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
4648
ASSERT_TRUE(succeeded(res));
4749
ASSERT_TRUE(succeeded(verify(module.get())));
4850

49-
// Check that it got renamed.
51+
// Check that callee of the call op got renamed.
5052
bool calleeFound = false;
5153
fooOp->walk([&](CallOpInterface callOp) {
5254
StringAttr callee = callOp.getCallableForCallee()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
5658
calleeFound = true;
5759
});
5860
EXPECT_TRUE(calleeFound);
61+
62+
// Check that module attribute did *not* get renamed.
63+
auto moduleAttr = (*module)->getAttrOfType<FlatSymbolRefAttr>("test.attr");
64+
ASSERT_TRUE(moduleAttr);
65+
EXPECT_EQ(moduleAttr.getValue(), StringRef("bar"));
5966
}
6067

6168
std::unique_ptr<MLIRContext> context;
69+
std::unique_ptr<OpBuilder> builder;
6270

6371
private:
6472
constexpr static llvm::StringLiteral kInput = R"MLIR(
65-
module {
73+
module attributes { test.attr = @bar } {
6674
test.conversion_func_op private @foo() {
6775
"test.conversion_call_op"() { callee=@bar } : () -> ()
6876
"test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
8189
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
8290
auto barOp) -> LogicalResult {
8391
return symbolTable.replaceAllSymbolUses(
84-
barOp, StringAttr::get(context.get(), "baz"), module);
92+
barOp, builder->getStringAttr("baz"), module);
8593
});
8694
}
8795

@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
9098
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
9199
auto barOp) -> LogicalResult {
92100
return symbolTable.replaceAllSymbolUses(
93-
StringAttr::get(context.get(), "bar"),
94-
StringAttr::get(context.get(), "baz"), module);
101+
builder->getStringAttr("bar"), builder->getStringAttr("baz"), module);
95102
});
96103
}
97104

@@ -100,17 +107,17 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
100107
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101108
auto barOp) -> LogicalResult {
102109
return symbolTable.replaceAllSymbolUses(
103-
barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
110+
barOp, builder->getStringAttr("baz"), &module->getRegion(0));
104111
});
105112
}
106113

107114
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108115
// Symbol as `StringAttr`, rename within module body.
109116
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110117
auto barOp) -> LogicalResult {
111-
return symbolTable.replaceAllSymbolUses(
112-
StringAttr::get(context.get(), "bar"),
113-
StringAttr::get(context.get(), "baz"), &module->getRegion(0));
118+
return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"),
119+
builder->getStringAttr("baz"),
120+
&module->getRegion(0));
114121
});
115122
}
116123

@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
119126
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120127
auto barOp) -> LogicalResult {
121128
return symbolTable.replaceAllSymbolUses(
122-
barOp, StringAttr::get(context.get(), "baz"), fooOp);
129+
barOp, builder->getStringAttr("baz"), fooOp);
123130
});
124131
}
125132

@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
128135
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129136
auto barOp) -> LogicalResult {
130137
return symbolTable.replaceAllSymbolUses(
131-
StringAttr::get(context.get(), "bar"),
132-
StringAttr::get(context.get(), "baz"), fooOp);
138+
builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp);
133139
});
134140
}
135141

0 commit comments

Comments
 (0)