Skip to content

Commit 96f54ed

Browse files
committed
fix: address comments from reviewers.
1 parent 0339abb commit 96f54ed

File tree

5 files changed

+182
-151
lines changed

5 files changed

+182
-151
lines changed

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ Token Lexer::lexIdentifier(const char *tokStart) {
329329
.Case("Value", Token::kw_Value)
330330
.Case("ValueRange", Token::kw_ValueRange)
331331
.Case("with", Token::kw_with)
332-
.Case("array", Token::kw_Array)
333332
.Case("_", Token::underscore)
334333
.Default(Token::identifier);
335334
return Token(kind, str);
@@ -379,7 +378,7 @@ Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
379378
--curPtr;
380379

381380
StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
382-
return emitError(curPtr - 1,
381+
return emitError(tokStart,
383382
"expected '" + expectedEndStr + "' in string literal");
384383
}
385384

mlir/lib/Tools/PDLL/Parser/Lexer.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class Token {
5454

5555
/// General keywords.
5656
kw_Attr,
57-
kw_Array,
5857
kw_erase,
5958
kw_let,
6059
kw_Constraint,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 132 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,11 @@ class Parser {
314314
FailureOr<ast::Expr *> parseExpr();
315315

316316
/// Identifier expressions.
317+
FailureOr<ast::Expr *> parseArrayAttrExpr();
317318
FailureOr<ast::Expr *> parseAttributeExpr();
318319
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
319320
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
321+
FailureOr<ast::Expr *> parseDictExpr();
320322
FailureOr<ast::Expr *> parseIdentifierExpr();
321323
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
322324
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
@@ -329,8 +331,6 @@ class Parser {
329331
FailureOr<ast::Expr *> parseTupleExpr();
330332
FailureOr<ast::Expr *> parseTypeExpr();
331333
FailureOr<ast::Expr *> parseUnderscoreExpr();
332-
FailureOr<ast::Expr *> parseDictExpr();
333-
FailureOr<ast::Expr *> parseArrayAttrExpr();
334334
//===--------------------------------------------------------------------===//
335335
// Stmts
336336

@@ -414,6 +414,13 @@ class Parser {
414414
FailureOr<ast::MemberAccessExpr *>
415415
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
416416

417+
// Create a native call with \p nativeFuncName and \p arguments.
418+
// This should be accompanied by a C++ implementation of the function that
419+
// needs to be linked and registered in passes that process PDLL files.
420+
FailureOr<ast::DeclRefExpr *>
421+
createNativeCall(SMRange loc, StringRef nativeFuncName,
422+
MutableArrayRef<ast::Expr *> arguments);
423+
417424
/// Validate the member access `name` into the given parent expression. On
418425
/// success, this also returns the type of the member accessed.
419426
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
@@ -443,9 +450,6 @@ class Parser {
443450
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
444451
ArrayRef<ast::Expr *> elements,
445452
ArrayRef<StringRef> elementNames);
446-
FailureOr<ast::DeclRefExpr *>
447-
createNativeCall(SMRange loc, StringRef nativeFuncName,
448-
MutableArrayRef<ast::Expr *> arguments);
449453

450454
//===--------------------------------------------------------------------===//
451455
// Stmts
@@ -1822,9 +1826,12 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
18221826
case Token::l_brace:
18231827
lhsExpr = parseDictExpr();
18241828
break;
1825-
case Token::kw_Array:
1829+
case Token::l_square:
18261830
lhsExpr = parseArrayAttrExpr();
18271831
break;
1832+
case Token::string_block:
1833+
return emitError("expected expression. If you are trying to create an "
1834+
"ArrayAttr, use a space between `[` and `{`.");
18281835
default:
18291836
return emitError("expected expression");
18301837
}
@@ -1848,6 +1855,40 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
18481855
}
18491856
}
18501857

1858+
FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
1859+
1860+
consumeToken(Token::l_square);
1861+
1862+
if (parserContext != ParserContext::Rewrite)
1863+
return emitError(
1864+
"Parsing of array attributes as constraint not supported!");
1865+
1866+
auto arrayAttrCall =
1867+
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
1868+
if (failed(arrayAttrCall))
1869+
return failure();
1870+
1871+
do {
1872+
FailureOr<ast::Expr *> attr = parseExpr();
1873+
if (failed(attr))
1874+
return failure();
1875+
1876+
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
1877+
auto elemToArrayCall = createNativeCall(
1878+
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
1879+
if (failed(elemToArrayCall))
1880+
return failure();
1881+
1882+
// Uses the new array for the next element.
1883+
arrayAttrCall = elemToArrayCall;
1884+
} while (consumeIf(Token::comma));
1885+
1886+
if (failed(
1887+
parseToken(Token::r_square, "expected `]` to close array attribute")))
1888+
return failure();
1889+
return arrayAttrCall;
1890+
}
1891+
18511892
FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
18521893
SMRange loc = curToken.getLoc();
18531894
consumeToken(Token::kw_attr);
@@ -1906,6 +1947,62 @@ FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
19061947
return createDeclRefExpr(loc, decl);
19071948
}
19081949

1950+
FailureOr<ast::Expr *> Parser::parseDictExpr() {
1951+
consumeToken(Token::l_brace);
1952+
SMRange loc = curToken.getLoc();
1953+
1954+
if (parserContext != ParserContext::Rewrite)
1955+
return emitError(
1956+
"Parsing of dictionary attributes as constraint not supported!");
1957+
1958+
auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
1959+
if (failed(dictAttrCall))
1960+
return failure();
1961+
1962+
// Add each nested attribute to the dict
1963+
do {
1964+
FailureOr<ast::NamedAttributeDecl *> decl =
1965+
parseNamedAttributeDecl(std::nullopt);
1966+
if (failed(decl))
1967+
return failure();
1968+
1969+
ast::NamedAttributeDecl *namedDecl = *decl;
1970+
1971+
std::string stringAttrValue =
1972+
"\"" + std::string((*namedDecl).getName().getName()) + "\"";
1973+
auto *stringAttr = ast::AttributeExpr::create(ctx, loc, stringAttrValue);
1974+
1975+
// Declare it as a variable
1976+
std::string anonName =
1977+
llvm::formatv("dict{0}", anonymousDeclNameCounter++).str();
1978+
FailureOr<ast::VariableDecl *> stringAttrDecl =
1979+
createVariableDecl(anonName, namedDecl->getLoc(), stringAttr, {});
1980+
if (failed(stringAttrDecl))
1981+
return failure();
1982+
1983+
// Get its reference
1984+
auto stringAttrRef = parseDeclRefExpr(
1985+
(*stringAttrDecl)->getName().getName(), namedDecl->getLoc());
1986+
if (failed(stringAttrRef))
1987+
return failure();
1988+
1989+
// Create addEntryToDictionaryAttr native call.
1990+
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
1991+
namedDecl->getValue()};
1992+
auto entryToDictionaryCall =
1993+
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
1994+
if (failed(entryToDictionaryCall))
1995+
return failure();
1996+
1997+
// Uses the new array for the next element.
1998+
dictAttrCall = entryToDictionaryCall;
1999+
} while (consumeIf(Token::comma));
2000+
if (failed(parseToken(Token::r_brace,
2001+
"expected `}` to close dictionary attribute")))
2002+
return failure();
2003+
return dictAttrCall;
2004+
}
2005+
19092006
FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
19102007
StringRef name = curToken.getSpelling();
19112008
SMRange nameLoc = curToken.getLoc();
@@ -2255,114 +2352,6 @@ FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
22552352
return createInlineVariableExpr(type, name, nameLoc, constraints);
22562353
}
22572354

2258-
FailureOr<ast::Expr *> Parser::parseDictExpr() {
2259-
consumeToken(Token::l_brace);
2260-
SMRange loc = curToken.getLoc();
2261-
2262-
if (parserContext != ParserContext::Rewrite)
2263-
return emitError(
2264-
"Parsing of dictionary attributes as constraint not supported!");
2265-
2266-
auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
2267-
if (failed(dictAttrCall))
2268-
return failure();
2269-
2270-
// Add each nested attribute to the dict
2271-
do {
2272-
FailureOr<ast::NamedAttributeDecl *> decl =
2273-
parseNamedAttributeDecl(llvm::None);
2274-
if (failed(decl))
2275-
return failure();
2276-
2277-
ast::NamedAttributeDecl *namedDecl = *decl;
2278-
2279-
std::string stringAttrValue =
2280-
"\"" + std::string((*namedDecl).getName().getName()) + "\"";
2281-
auto *stringAttr = ast::AttributeExpr::create(ctx, loc, stringAttrValue);
2282-
2283-
// Declare it as a variable
2284-
std::string anonName =
2285-
llvm::formatv("dict{0}", anonymousDeclNameCounter++).str();
2286-
FailureOr<ast::VariableDecl *> stringAttrDecl =
2287-
createVariableDecl(anonName, namedDecl->getLoc(), stringAttr, {});
2288-
if (failed(stringAttrDecl))
2289-
return failure();
2290-
2291-
// Get its reference
2292-
auto stringAttrRef = parseDeclRefExpr(
2293-
(*stringAttrDecl)->getName().getName(), namedDecl->getLoc());
2294-
if (failed(stringAttrRef))
2295-
return failure();
2296-
2297-
// Create addEntryToDictionaryAttr native call.
2298-
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
2299-
namedDecl->getValue()};
2300-
auto entryToDictionaryCall =
2301-
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
2302-
if (failed(entryToDictionaryCall))
2303-
return failure();
2304-
2305-
// Uses the new array for the next element.
2306-
dictAttrCall = entryToDictionaryCall;
2307-
} while (consumeIf(Token::comma));
2308-
if (failed(parseToken(Token::r_brace,
2309-
"expected `}` to close dictionary attribute")))
2310-
return failure();
2311-
return dictAttrCall;
2312-
}
2313-
2314-
FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
2315-
2316-
// Advance to the next token without failing.
2317-
auto nextToken = [&](Token &curToken, int64_t offset) {
2318-
SMRange loc = curToken.getLoc();
2319-
SMRange dictArraysLoc(
2320-
loc.Start.getFromPointer(loc.Start.getPointer() + offset),
2321-
curToken.getEndLoc());
2322-
resetToken(dictArraysLoc);
2323-
};
2324-
2325-
SMRange loc = curToken.getLoc();
2326-
2327-
const char tokenAfterArray = *loc.End.getPointer();
2328-
if (tokenAfterArray != '[')
2329-
return emitError(curToken.getLoc(), "expected `[` after `array`.");
2330-
2331-
// Consume `array[` token by advancing 6 characters.
2332-
// Since the lexer misinterprets `[{` as a string_block, we can't consume the
2333-
// array token in the normal way. Instead, advance to the next token without
2334-
// looking at the new Token::Kind.
2335-
nextToken(curToken, 6);
2336-
2337-
if (parserContext != ParserContext::Rewrite)
2338-
return emitError(
2339-
"Parsing of array attributes as constraint not supported!");
2340-
2341-
auto arrayAttrCall = createNativeCall(loc, "createArrayAttr", {});
2342-
if (failed(arrayAttrCall))
2343-
return failure();
2344-
2345-
do {
2346-
FailureOr<ast::Expr *> attr = parseExpr();
2347-
if (failed(attr))
2348-
return failure();
2349-
2350-
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
2351-
auto elemToArrayCall =
2352-
createNativeCall(loc, "addElemToArrayAttr", arrayAttrArgs);
2353-
if (failed(elemToArrayCall))
2354-
return failure();
2355-
2356-
// Uses the new array for the next element.
2357-
arrayAttrCall = elemToArrayCall;
2358-
} while (consumeIf(Token::comma));
2359-
2360-
if (failed(
2361-
parseToken(Token::r_square, "expected `]` to close array attribute")))
2362-
return failure();
2363-
return arrayAttrCall;
2364-
}
2365-
23662355
//===----------------------------------------------------------------------===//
23672356
// Stmts
23682357

@@ -2887,6 +2876,35 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
28872876
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
28882877
}
28892878

2879+
FailureOr<ast::DeclRefExpr *>
2880+
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
2881+
MutableArrayRef<ast::Expr *> arguments) {
2882+
2883+
FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
2884+
if (failed(nativeFuncExpr))
2885+
return emitError(nativeFuncName + " not found.");
2886+
2887+
if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
2888+
return emitError(nativeFuncName + " should be defined as a rewriter.");
2889+
2890+
FailureOr<ast::CallExpr *> nativeCall =
2891+
createCallExpr(loc, *nativeFuncExpr, arguments);
2892+
if (failed(nativeCall))
2893+
return failure();
2894+
2895+
// Create a unique anonymous name to use, as the name for this decl is not
2896+
// important.
2897+
std::string anonName =
2898+
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
2899+
.str();
2900+
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
2901+
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
2902+
if (failed(varDecl))
2903+
return failure();
2904+
2905+
return createDeclRefExpr(loc, *varDecl);
2906+
}
2907+
28902908
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
28912909
StringRef name, SMRange loc) {
28922910
ast::Type parentType = parentExpr->getType();
@@ -3166,38 +3184,6 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
31663184
return ast::TupleExpr::create(ctx, loc, elements, elementNames);
31673185
}
31683186

3169-
FailureOr<ast::DeclRefExpr *>
3170-
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
3171-
MutableArrayRef<ast::Expr *> arguments) {
3172-
3173-
FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
3174-
if (failed(nativeFuncExpr))
3175-
return emitError(nativeFuncName + " not found.");
3176-
3177-
if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
3178-
return emitError(nativeFuncName + " should be defined as a rewriter.");
3179-
3180-
FailureOr<ast::CallExpr *> nativeCall =
3181-
createCallExpr(loc, *nativeFuncExpr, arguments);
3182-
if (failed(nativeCall))
3183-
return failure();
3184-
3185-
// Create a unique anonymous name to use, as the name for this decl is not
3186-
// important.
3187-
std::string anonName =
3188-
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
3189-
.str();
3190-
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
3191-
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
3192-
if (failed(varDecl))
3193-
return failure();
3194-
3195-
FailureOr<ast::DeclRefExpr *> arrayAttrReference =
3196-
createDeclRefExpr(loc, *varDecl);
3197-
3198-
return *arrayAttrReference;
3199-
}
3200-
32013187
//===----------------------------------------------------------------------===//
32023188
// Stmts
32033189

0 commit comments

Comments
 (0)