@@ -328,7 +328,8 @@ class Parser {
328
328
FailureOr<ast::Expr *> parseTupleExpr ();
329
329
FailureOr<ast::Expr *> parseTypeExpr ();
330
330
FailureOr<ast::Expr *> parseUnderscoreExpr ();
331
-
331
+ FailureOr<ast::Expr *> parseDictExpr ();
332
+ FailureOr<ast::Expr *> parseArrayAttrExpr ();
332
333
// ===--------------------------------------------------------------------===//
333
334
// Stmts
334
335
@@ -440,6 +441,9 @@ class Parser {
440
441
FailureOr<ast::TupleExpr *> createTupleExpr (SMRange loc,
441
442
ArrayRef<ast::Expr *> elements,
442
443
ArrayRef<StringRef> elementNames);
444
+ FailureOr<ast::DeclRefExpr *>
445
+ createNativeCall (SMRange loc, StringRef nativeFuncName,
446
+ MutableArrayRef<ast::Expr *> arguments);
443
447
444
448
// ===--------------------------------------------------------------------===//
445
449
// Stmts
@@ -1813,6 +1817,12 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
1813
1817
case Token::l_paren:
1814
1818
lhsExpr = parseTupleExpr ();
1815
1819
break ;
1820
+ case Token::l_brace:
1821
+ lhsExpr = parseDictExpr ();
1822
+ break ;
1823
+ case Token::kw_Array:
1824
+ lhsExpr = parseArrayAttrExpr ();
1825
+ break ;
1816
1826
default :
1817
1827
return emitError (" expected expression" );
1818
1828
}
@@ -2243,6 +2253,114 @@ FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
2243
2253
return createInlineVariableExpr (type, name, nameLoc, constraints);
2244
2254
}
2245
2255
2256
+ FailureOr<ast::Expr *> Parser::parseDictExpr () {
2257
+ consumeToken (Token::l_brace);
2258
+ SMRange loc = curToken.getLoc ();
2259
+
2260
+ if (parserContext != ParserContext::Rewrite)
2261
+ return emitError (
2262
+ " Parsing of dictionary attributes as constraint not supported!" );
2263
+
2264
+ auto dictAttrCall = createNativeCall (loc, " createDictionaryAttr" , {});
2265
+ if (failed (dictAttrCall))
2266
+ return failure ();
2267
+
2268
+ // Add each nested attribute to the dict
2269
+ do {
2270
+ FailureOr<ast::NamedAttributeDecl *> decl =
2271
+ parseNamedAttributeDecl (llvm::None);
2272
+ if (failed (decl))
2273
+ return failure ();
2274
+
2275
+ ast::NamedAttributeDecl *namedDecl = *decl;
2276
+
2277
+ std::string stringAttrValue =
2278
+ " \" " + std::string ((*namedDecl).getName ().getName ()) + " \" " ;
2279
+ auto *stringAttr = ast::AttributeExpr::create (ctx, loc, stringAttrValue);
2280
+
2281
+ // Declare it as a variable
2282
+ std::string anonName =
2283
+ llvm::formatv (" dict{0}" , anonymousDeclNameCounter++).str ();
2284
+ FailureOr<ast::VariableDecl *> stringAttrDecl =
2285
+ createVariableDecl (anonName, namedDecl->getLoc (), stringAttr, {});
2286
+ if (failed (stringAttrDecl))
2287
+ return failure ();
2288
+
2289
+ // Get its reference
2290
+ auto stringAttrRef = parseDeclRefExpr (
2291
+ (*stringAttrDecl)->getName ().getName (), namedDecl->getLoc ());
2292
+ if (failed (stringAttrRef))
2293
+ return failure ();
2294
+
2295
+ // Create addEntryToDictionaryAttr native call.
2296
+ SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
2297
+ namedDecl->getValue ()};
2298
+ auto entryToDictionaryCall =
2299
+ createNativeCall (loc, " addEntryToDictionaryAttr" , arrayAttrArgs);
2300
+ if (failed (entryToDictionaryCall))
2301
+ return failure ();
2302
+
2303
+ // Uses the new array for the next element.
2304
+ dictAttrCall = entryToDictionaryCall;
2305
+ } while (consumeIf (Token::comma));
2306
+ if (failed (parseToken (Token::r_brace,
2307
+ " expected `}` to close dictionary attribute" )))
2308
+ return failure ();
2309
+ return dictAttrCall;
2310
+ }
2311
+
2312
+ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr () {
2313
+
2314
+ // Advance to the next token without failing.
2315
+ auto nextToken = [&](Token &curToken, int64_t offset) {
2316
+ SMRange loc = curToken.getLoc ();
2317
+ SMRange dictArraysLoc (
2318
+ loc.Start .getFromPointer (loc.Start .getPointer () + offset),
2319
+ curToken.getEndLoc ());
2320
+ resetToken (dictArraysLoc);
2321
+ };
2322
+
2323
+ SMRange loc = curToken.getLoc ();
2324
+
2325
+ const char tokenAfterArray = *loc.End .getPointer ();
2326
+ if (tokenAfterArray != ' [' )
2327
+ return emitError (curToken.getLoc (), " expected `[` after `array`." );
2328
+
2329
+ // Consume `array[` token by advancing 6 characters.
2330
+ // Since the lexer misinterprets `[{` as a string_block, we can't consume the
2331
+ // array token in the normal way. Instead, advance to the next token without
2332
+ // looking at the new Token::Kind.
2333
+ nextToken (curToken, 6 );
2334
+
2335
+ if (parserContext != ParserContext::Rewrite)
2336
+ return emitError (
2337
+ " Parsing of array attributes as constraint not supported!" );
2338
+
2339
+ auto arrayAttrCall = createNativeCall (loc, " createArrayAttr" , {});
2340
+ if (failed (arrayAttrCall))
2341
+ return failure ();
2342
+
2343
+ do {
2344
+ FailureOr<ast::Expr *> attr = parseExpr ();
2345
+ if (failed (attr))
2346
+ return failure ();
2347
+
2348
+ SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
2349
+ auto elemToArrayCall =
2350
+ createNativeCall (loc, " addElemToArrayAttr" , arrayAttrArgs);
2351
+ if (failed (elemToArrayCall))
2352
+ return failure ();
2353
+
2354
+ // Uses the new array for the next element.
2355
+ arrayAttrCall = elemToArrayCall;
2356
+ } while (consumeIf (Token::comma));
2357
+
2358
+ if (failed (
2359
+ parseToken (Token::r_square, " expected `]` to close array attribute" )))
2360
+ return failure ();
2361
+ return arrayAttrCall;
2362
+ }
2363
+
2246
2364
// ===----------------------------------------------------------------------===//
2247
2365
// Stmts
2248
2366
@@ -3047,6 +3165,38 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
3047
3165
return ast::TupleExpr::create (ctx, loc, elements, elementNames);
3048
3166
}
3049
3167
3168
+ FailureOr<ast::DeclRefExpr *>
3169
+ Parser::createNativeCall (SMRange loc, StringRef nativeFuncName,
3170
+ MutableArrayRef<ast::Expr *> arguments) {
3171
+
3172
+ FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr (nativeFuncName, loc);
3173
+ if (failed (nativeFuncExpr))
3174
+ return emitError (nativeFuncName + " not found." );
3175
+
3176
+ if (isa<mlir::pdll::ast::RewriteStmt>(*nativeFuncExpr))
3177
+ return emitError (nativeFuncName + " should be defined as a rewriter." );
3178
+
3179
+ FailureOr<ast::CallExpr *> nativeCall =
3180
+ createCallExpr (loc, *nativeFuncExpr, arguments);
3181
+ if (failed (nativeCall))
3182
+ return failure ();
3183
+
3184
+ // Create a unique anonymous name to use, as the name for this decl is not
3185
+ // important.
3186
+ std::string anonName =
3187
+ llvm::formatv (" {0}_{1}" , nativeFuncName, anonymousDeclNameCounter++)
3188
+ .str ();
3189
+ FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl (
3190
+ anonName, loc, (*nativeCall)->getType (), *nativeCall, {});
3191
+ if (failed (varDecl))
3192
+ return failure ();
3193
+
3194
+ FailureOr<ast::DeclRefExpr *> arrayAttrReference =
3195
+ createDeclRefExpr (loc, *varDecl);
3196
+
3197
+ return *arrayAttrReference;
3198
+ }
3199
+
3050
3200
// ===----------------------------------------------------------------------===//
3051
3201
// Stmts
3052
3202
0 commit comments