Skip to content

Commit b12189d

Browse files
committed
Feat: add support for Dictionary and Array attributes in PDLL rewrite
sections.
1 parent f891e20 commit b12189d

File tree

4 files changed

+249
-1
lines changed

4 files changed

+249
-1
lines changed

mlir/lib/Tools/PDLL/Parser/Lexer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
329329
.Case("Value", Token::kw_Value)
330330
.Case("ValueRange", Token::kw_ValueRange)
331331
.Case("with", Token::kw_with)
332+
.Case("array", Token::kw_Array)
332333
.Case("_", Token::underscore)
333334
.Default(Token::identifier);
334335
return Token(kind, str);

mlir/lib/Tools/PDLL/Parser/Lexer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Token {
5454

5555
/// General keywords.
5656
kw_Attr,
57+
kw_Array,
5758
kw_erase,
5859
kw_let,
5960
kw_Constraint,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ class Parser {
328328
FailureOr<ast::Expr *> parseTupleExpr();
329329
FailureOr<ast::Expr *> parseTypeExpr();
330330
FailureOr<ast::Expr *> parseUnderscoreExpr();
331-
331+
FailureOr<ast::Expr *> parseDictExpr();
332+
FailureOr<ast::Expr *> parseArrayAttrExpr();
332333
//===--------------------------------------------------------------------===//
333334
// Stmts
334335

@@ -440,6 +441,9 @@ class Parser {
440441
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
441442
ArrayRef<ast::Expr *> elements,
442443
ArrayRef<StringRef> elementNames);
444+
FailureOr<ast::DeclRefExpr *>
445+
createNativeCall(SMRange loc, StringRef nativeFuncName,
446+
MutableArrayRef<ast::Expr *> arguments);
443447

444448
//===--------------------------------------------------------------------===//
445449
// Stmts
@@ -1813,6 +1817,12 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
18131817
case Token::l_paren:
18141818
lhsExpr = parseTupleExpr();
18151819
break;
1820+
case Token::l_brace:
1821+
lhsExpr = parseDictExpr();
1822+
break;
1823+
case Token::kw_Array:
1824+
lhsExpr = parseArrayAttrExpr();
1825+
break;
18161826
default:
18171827
return emitError("expected expression");
18181828
}
@@ -2243,6 +2253,114 @@ FailureOr<ast::Expr *> Parser::parseUnderscoreExpr() {
22432253
return createInlineVariableExpr(type, name, nameLoc, constraints);
22442254
}
22452255

2256+
FailureOr<ast::Expr *> Parser::parseDictExpr() {
2257+
consumeToken(Token::l_brace);
2258+
SMRange loc = curToken.getLoc();
2259+
2260+
if (parserContext != ParserContext::Rewrite)
2261+
return emitError(
2262+
"Parsing of dictionary attributes as constraint not supported!");
2263+
2264+
auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
2265+
if (failed(dictAttrCall))
2266+
return failure();
2267+
2268+
// Add each nested attribute to the dict
2269+
do {
2270+
FailureOr<ast::NamedAttributeDecl *> decl =
2271+
parseNamedAttributeDecl(llvm::None);
2272+
if (failed(decl))
2273+
return failure();
2274+
2275+
ast::NamedAttributeDecl *namedDecl = *decl;
2276+
2277+
std::string stringAttrValue =
2278+
"\"" + std::string((*namedDecl).getName().getName()) + "\"";
2279+
auto *stringAttr = ast::AttributeExpr::create(ctx, loc, stringAttrValue);
2280+
2281+
// Declare it as a variable
2282+
std::string anonName =
2283+
llvm::formatv("dict{0}", anonymousDeclNameCounter++).str();
2284+
FailureOr<ast::VariableDecl *> stringAttrDecl =
2285+
createVariableDecl(anonName, namedDecl->getLoc(), stringAttr, {});
2286+
if (failed(stringAttrDecl))
2287+
return failure();
2288+
2289+
// Get its reference
2290+
auto stringAttrRef = parseDeclRefExpr(
2291+
(*stringAttrDecl)->getName().getName(), namedDecl->getLoc());
2292+
if (failed(stringAttrRef))
2293+
return failure();
2294+
2295+
// Create addEntryToDictionaryAttr native call.
2296+
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
2297+
namedDecl->getValue()};
2298+
auto entryToDictionaryCall =
2299+
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
2300+
if (failed(entryToDictionaryCall))
2301+
return failure();
2302+
2303+
// Uses the new array for the next element.
2304+
dictAttrCall = entryToDictionaryCall;
2305+
} while (consumeIf(Token::comma));
2306+
if (failed(parseToken(Token::r_brace,
2307+
"expected `}` to close dictionary attribute")))
2308+
return failure();
2309+
return dictAttrCall;
2310+
}
2311+
2312+
FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
2313+
2314+
// Advance to the next token without failing.
2315+
auto nextToken = [&](Token &curToken, int64_t offset) {
2316+
SMRange loc = curToken.getLoc();
2317+
SMRange dictArraysLoc(
2318+
loc.Start.getFromPointer(loc.Start.getPointer() + offset),
2319+
curToken.getEndLoc());
2320+
resetToken(dictArraysLoc);
2321+
};
2322+
2323+
SMRange loc = curToken.getLoc();
2324+
2325+
const char tokenAfterArray = *loc.End.getPointer();
2326+
if (tokenAfterArray != '[')
2327+
return emitError(curToken.getLoc(), "expected `[` after `array`.");
2328+
2329+
// Consume `array[` token by advancing 6 characters.
2330+
// Since the lexer misinterprets `[{` as a string_block, we can't consume the
2331+
// array token in the normal way. Instead, advance to the next token without
2332+
// looking at the new Token::Kind.
2333+
nextToken(curToken, 6);
2334+
2335+
if (parserContext != ParserContext::Rewrite)
2336+
return emitError(
2337+
"Parsing of array attributes as constraint not supported!");
2338+
2339+
auto arrayAttrCall = createNativeCall(loc, "createArrayAttr", {});
2340+
if (failed(arrayAttrCall))
2341+
return failure();
2342+
2343+
do {
2344+
FailureOr<ast::Expr *> attr = parseExpr();
2345+
if (failed(attr))
2346+
return failure();
2347+
2348+
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
2349+
auto elemToArrayCall =
2350+
createNativeCall(loc, "addElemToArrayAttr", arrayAttrArgs);
2351+
if (failed(elemToArrayCall))
2352+
return failure();
2353+
2354+
// Uses the new array for the next element.
2355+
arrayAttrCall = elemToArrayCall;
2356+
} while (consumeIf(Token::comma));
2357+
2358+
if (failed(
2359+
parseToken(Token::r_square, "expected `]` to close array attribute")))
2360+
return failure();
2361+
return arrayAttrCall;
2362+
}
2363+
22462364
//===----------------------------------------------------------------------===//
22472365
// Stmts
22482366

@@ -3047,6 +3165,38 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
30473165
return ast::TupleExpr::create(ctx, loc, elements, elementNames);
30483166
}
30493167

3168+
FailureOr<ast::DeclRefExpr *>
3169+
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
3170+
MutableArrayRef<ast::Expr *> arguments) {
3171+
3172+
FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
3173+
if (failed(nativeFuncExpr))
3174+
return emitError(nativeFuncName + " not found.");
3175+
3176+
if (isa<mlir::pdll::ast::RewriteStmt>(*nativeFuncExpr))
3177+
return emitError(nativeFuncName + " should be defined as a rewriter.");
3178+
3179+
FailureOr<ast::CallExpr *> nativeCall =
3180+
createCallExpr(loc, *nativeFuncExpr, arguments);
3181+
if (failed(nativeCall))
3182+
return failure();
3183+
3184+
// Create a unique anonymous name to use, as the name for this decl is not
3185+
// important.
3186+
std::string anonName =
3187+
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
3188+
.str();
3189+
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
3190+
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
3191+
if (failed(varDecl))
3192+
return failure();
3193+
3194+
FailureOr<ast::DeclRefExpr *> arrayAttrReference =
3195+
createDeclRefExpr(loc, *varDecl);
3196+
3197+
return *arrayAttrReference;
3198+
}
3199+
30503200
//===----------------------------------------------------------------------===//
30513201
// Stmts
30523202

mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,99 @@ Pattern RangeExpr {
142142
// CHECK: %[[TYPE:.*]] = type : i32
143143
// CHECK: operation({{.*}}) -> (%[[TYPE]] : !pdl.type)
144144
Pattern TypeExpr => erase op<> -> (type<"i32">);
145+
146+
// -----
147+
148+
//===----------------------------------------------------------------------===//
149+
// Parse attributes and rewrite
150+
//===----------------------------------------------------------------------===//
151+
152+
// Rewriter helpers declarations.
153+
Rewrite createDictionaryAttr() -> Attr;
154+
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;
155+
Rewrite createArrayAttr() -> Attr;
156+
Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;
157+
158+
// CHECK-LABEL: pdl.pattern @RewriteOneEntryDictionary
159+
// CHECK: %[[VAL_1:.*]] = operation "test.op"
160+
// CHECK: %[[VAL_2:.*]] = attribute = "test"
161+
// CHECK: rewrite %[[VAL_1]] {
162+
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr"
163+
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
164+
// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]
165+
// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]}
166+
// CHECK: replace %[[VAL_1]] with %[[VAL_6]]
167+
Pattern RewriteOneEntryDictionary {
168+
let root = op<test.op> -> ();
169+
let attr1 = attr<"\"test\"">;
170+
rewrite root with {
171+
let newRoot = op<test.success>() { some_dictionary = {firstAttr=attr1} } -> ();
172+
replace root with newRoot;
173+
};
174+
}
175+
176+
// CHECK-LABEL: pdl.pattern @RewriteMultipleEntriesDictionary
177+
// CHECK: %[[VAL_1:.*]] = operation "test.op"
178+
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
179+
// CHECK: %[[VAL_3:.*]] = attribute = "test3"
180+
// CHECK: rewrite %[[VAL_1]] {
181+
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr"
182+
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
183+
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
184+
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
185+
// CHECK: %[[VAL_8:.*]] = attribute = "secondAttr"
186+
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]]
187+
// CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr"
188+
// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]
189+
// CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]}
190+
// CHECK: replace %[[VAL_1]] with %[[VAL_12]]
191+
Pattern RewriteMultipleEntriesDictionary {
192+
let root = op<test.op> -> ();
193+
let attr2 = attr<"\"test2\"">;
194+
let attr3 = attr<"\"test3\"">;
195+
rewrite root with {
196+
let newRoot = op<test.success>() { some_dictionary = {"firstAttr" = attr<"\"test1\"">, secondAttr = attr2, thirdAttr = attr3} } -> ();
197+
replace root with newRoot;
198+
};
199+
}
200+
201+
// CHECK-LABEL: pdl.pattern @RewriteOneDictionaryArrayAttr
202+
// CHECK: %[[VAL_1:.*]] = operation "test.op"
203+
// CHECK: rewrite %[[VAL_1]] {
204+
// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "createArrayAttr"
205+
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr"
206+
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
207+
// CHECK: %[[VAL_5:.*]] = attribute = "test1"
208+
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
209+
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]]
210+
// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]}
211+
// CHECK: replace %[[VAL_1]] with %[[VAL_8]]
212+
Pattern RewriteOneDictionaryArrayAttr {
213+
let root = op<test.op> -> ();
214+
rewrite root with {
215+
let newRoot = op<test.success>() { some_array = array[{"firstAttr" = attr<"\"test1\"">}]} -> ();
216+
replace root with newRoot;
217+
};
218+
}
219+
220+
// CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr
221+
// CHECK: %[[VAL_1:.*]] = operation "test.op"
222+
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
223+
// CHECK: rewrite %[[VAL_1]] {
224+
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createArrayAttr"
225+
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr"
226+
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
227+
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
228+
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
229+
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]]
230+
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]]
231+
// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]}
232+
// CHECK: replace %[[VAL_1]] with %[[VAL_10]]
233+
Pattern RewriteMultiplyElementsArrayAttr {
234+
let root = op<test.op> -> ();
235+
let attr2 = attr<"\"test2\"">;
236+
rewrite root with {
237+
let newRoot = op<test.success>() { some_array = array[{"firstAttr" = attr<"\"test1\"">}, attr2]} -> ();
238+
replace root with newRoot;
239+
};
240+
}

0 commit comments

Comments
 (0)