Skip to content

Commit aa21ce4

Browse files
authored
[mlir] Do not set lastToken in AsmParser's resetToken function and add a unit test for AsmParsers's locations (#105529)
This changes the function `resetToken` to not update `lastToken`. The member `lastToken` is the last token that was consumed by the parser. Resetting the lexer position to a different position does not cause any token to be consumed, so `lastToken` should not be updated. Setting it to `curToken` can cause the scopeLoc.end location of `OperationDefinition `to be off-by-one, pointing to the first token after the operation. An example for an operation for which the scopeLoc.end location was wrong before is: ``` %0 = torch.vtensor.literal(dense_resource<__elided__> : tensor<768xbf16>) : !torch.vtensor<[768],bf16> ``` Here the scope end loc always pointed to the next token This also adds a test for the Locations of `OperationDefinitions`. Without the change to `resetToken` the test failes, with the scope end location for `llvm.mlir.undef` pointing to the `func.return` in the next line
1 parent b98aa6f commit aa21ce4

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

mlir/lib/AsmParser/Parser.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,14 @@ class Parser {
130130
consumeToken();
131131
}
132132

133-
/// Reset the parser to the given lexer position.
133+
/// Reset the parser to the given lexer position. Resetting the parser/lexer
134+
/// position does not update 'state.lastToken'. 'state.lastToken' is the
135+
/// last parsed token, and is used to provide the scope end location for
136+
/// OperationDefinitions. To ensure the correctness of the end location, the
137+
/// last consumed token of an OperationDefinition needs to be the last token
138+
/// belonging to it.
134139
void resetToken(const char *tokPos) {
135140
state.lex.resetPointer(tokPos);
136-
state.lastToken = state.curToken;
137141
state.curToken = state.lex.lexToken();
138142
}
139143

mlir/unittests/Parser/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ add_mlir_unittest(MLIRParserTests
88
target_include_directories(MLIRParserTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test")
99

1010
target_link_libraries(MLIRParserTests PRIVATE
11+
MLIRFuncDialect
12+
MLIRLLVMDialect
1113
MLIRIR
1214
MLIRParser
1315
MLIRTestDialect

mlir/unittests/Parser/ParserTest.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88

99
#include "mlir/Parser/Parser.h"
1010
#include "mlir/AsmParser/AsmParser.h"
11+
#include "mlir/AsmParser/AsmParserState.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1114
#include "mlir/IR/BuiltinOps.h"
1215
#include "mlir/IR/Verifier.h"
16+
#include "llvm/Support/SourceMgr.h"
1317

1418
#include "gmock/gmock.h"
1519

@@ -101,4 +105,81 @@ TEST(MLIRParser, ParseAttr) {
101105
EXPECT_EQ(attr, b.getI64IntegerAttr(9));
102106
}
103107
}
108+
109+
TEST(MLIRParser, AsmParserLocations) {
110+
std::string moduleStr = R"mlir(
111+
func.func @foo() -> !llvm.array<2 x f32> {
112+
%0 = llvm.mlir.undef : !llvm.array<2 x f32>
113+
func.return %0 : !llvm.array<2 x f32>
114+
}
115+
)mlir";
116+
117+
DialectRegistry registry;
118+
registry.insert<func::FuncDialect, LLVM::LLVMDialect>();
119+
MLIRContext context(registry);
120+
121+
auto memBuffer =
122+
llvm::MemoryBuffer::getMemBuffer(moduleStr, "AsmParserTest.mlir",
123+
/*RequiresNullTerminator=*/false);
124+
ASSERT_TRUE(memBuffer);
125+
126+
llvm::SourceMgr sourceMgr;
127+
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), llvm::SMLoc());
128+
129+
Block block;
130+
AsmParserState parseState;
131+
const LogicalResult parseResult =
132+
parseAsmSourceFile(sourceMgr, &block, &context, &parseState);
133+
ASSERT_TRUE(parseResult.succeeded());
134+
135+
auto funcOp = *block.getOps<func::FuncOp>().begin();
136+
const AsmParserState::OperationDefinition *funcOpDefinition =
137+
parseState.getOpDef(funcOp);
138+
ASSERT_TRUE(funcOpDefinition);
139+
140+
const std::pair expectedStartFunc{2u, 1u};
141+
const std::pair expectedEndFunc{2u, 10u};
142+
const std::pair expectedScopeEndFunc{5u, 2u};
143+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcOpDefinition->loc.Start),
144+
expectedStartFunc);
145+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcOpDefinition->loc.End),
146+
expectedEndFunc);
147+
ASSERT_EQ(funcOpDefinition->loc.Start, funcOpDefinition->scopeLoc.Start);
148+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcOpDefinition->scopeLoc.End),
149+
expectedScopeEndFunc);
150+
151+
auto llvmUndef = *funcOp.getOps<LLVM::UndefOp>().begin();
152+
const AsmParserState::OperationDefinition *llvmUndefDefinition =
153+
parseState.getOpDef(llvmUndef);
154+
ASSERT_TRUE(llvmUndefDefinition);
155+
156+
const std::pair expectedStartUndef{3u, 8u};
157+
const std::pair expectedEndUndef{3u, 23u};
158+
const std::pair expectedScopeEndUndef{3u, 46u};
159+
ASSERT_EQ(sourceMgr.getLineAndColumn(llvmUndefDefinition->loc.Start),
160+
expectedStartUndef);
161+
ASSERT_EQ(sourceMgr.getLineAndColumn(llvmUndefDefinition->loc.End),
162+
expectedEndUndef);
163+
ASSERT_EQ(llvmUndefDefinition->loc.Start,
164+
llvmUndefDefinition->scopeLoc.Start);
165+
ASSERT_EQ(sourceMgr.getLineAndColumn(llvmUndefDefinition->scopeLoc.End),
166+
expectedScopeEndUndef);
167+
168+
auto funcReturn = *funcOp.getOps<func::ReturnOp>().begin();
169+
const AsmParserState::OperationDefinition *funcReturnDefinition =
170+
parseState.getOpDef(funcReturn);
171+
ASSERT_TRUE(funcReturnDefinition);
172+
173+
const std::pair expectedStartReturn{4u, 3u};
174+
const std::pair expectedEndReturn{4u, 14u};
175+
const std::pair expectedScopeEndReturn{4u, 40u};
176+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcReturnDefinition->loc.Start),
177+
expectedStartReturn);
178+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcReturnDefinition->loc.End),
179+
expectedEndReturn);
180+
ASSERT_EQ(funcReturnDefinition->loc.Start,
181+
funcReturnDefinition->scopeLoc.Start);
182+
ASSERT_EQ(sourceMgr.getLineAndColumn(funcReturnDefinition->scopeLoc.End),
183+
expectedScopeEndReturn);
184+
}
104185
} // namespace

0 commit comments

Comments
 (0)