Skip to content

Feat: add support for Dictionary and Array attributes in PDLL rewrite sections. #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
--curPtr;

StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
return emitError(curPtr - 1,
return emitError(tokStart,
"expected '" + expectedEndStr + "' in string literal");
}

Expand Down
138 changes: 137 additions & 1 deletion mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,11 @@ class Parser {
FailureOr<ast::Expr *> parseExpr();

/// Identifier expressions.
FailureOr<ast::Expr *> parseArrayAttrExpr();
FailureOr<ast::Expr *> parseAttributeExpr();
FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseDictAttrExpr();
FailureOr<ast::Expr *> parseIdentifierExpr();
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
Expand All @@ -329,7 +331,6 @@ class Parser {
FailureOr<ast::Expr *> parseTupleExpr();
FailureOr<ast::Expr *> parseTypeExpr();
FailureOr<ast::Expr *> parseUnderscoreExpr();

//===--------------------------------------------------------------------===//
// Stmts

Expand Down Expand Up @@ -413,6 +414,13 @@ class Parser {
FailureOr<ast::MemberAccessExpr *>
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);

// Create a native call with \p nativeFuncName and \p arguments.
// This should be accompanied by a C++ implementation of the function that
// needs to be linked and registered in passes that process PDLL files.
FailureOr<ast::DeclRefExpr *>
createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments);

/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
Expand Down Expand Up @@ -1815,6 +1823,15 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::l_paren:
lhsExpr = parseTupleExpr();
break;
case Token::l_brace:
lhsExpr = parseDictAttrExpr();
break;
case Token::l_square:
lhsExpr = parseArrayAttrExpr();
break;
case Token::string_block:
return emitError("expected expression. If you are trying to create an "
"ArrayAttr, use a space between `[` and `{`.");
default:
return emitError("expected expression");
}
Expand All @@ -1838,6 +1855,40 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
}
}

FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {

consumeToken(Token::l_square);

if (parserContext != ParserContext::Rewrite)
return emitError(
"Parsing of array attributes as constraint not supported!");

auto arrayAttrCall =
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
if (failed(arrayAttrCall))
return failure();

do {
FailureOr<ast::Expr *> attr = parseExpr();
if (failed(attr))
return failure();

SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
auto elemToArrayCall = createNativeCall(
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
if (failed(elemToArrayCall))
return failure();

// Uses the new array for the next element.
arrayAttrCall = elemToArrayCall;
} while (consumeIf(Token::comma));

if (failed(
parseToken(Token::r_square, "expected `]` to close array attribute")))
return failure();
return arrayAttrCall;
}

FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_attr);
Expand Down Expand Up @@ -1896,6 +1947,62 @@ FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
return createDeclRefExpr(loc, decl);
}

FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
consumeToken(Token::l_brace);
SMRange loc = curToken.getLoc();

if (parserContext != ParserContext::Rewrite)
return emitError(
"Parsing of dictionary attributes as constraint not supported!");

auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
if (failed(dictAttrCall))
return failure();

// Add each nested attribute to the dict
do {
FailureOr<ast::NamedAttributeDecl *> decl =
parseNamedAttributeDecl(std::nullopt);
if (failed(decl))
return failure();

ast::NamedAttributeDecl *namedDecl = *decl;

std::string stringAttrValue =
"\"" + std::string((*namedDecl).getName().getName()) + "\"";
auto *stringAttr = ast::AttributeExpr::create(ctx, loc, stringAttrValue);

// Declare it as a variable
std::string anonName =
llvm::formatv("dict{0}", anonymousDeclNameCounter++).str();
FailureOr<ast::VariableDecl *> stringAttrDecl =
createVariableDecl(anonName, namedDecl->getLoc(), stringAttr, {});
if (failed(stringAttrDecl))
return failure();

// Get its reference
auto stringAttrRef = parseDeclRefExpr(
(*stringAttrDecl)->getName().getName(), namedDecl->getLoc());
if (failed(stringAttrRef))
return failure();

// Create addEntryToDictionaryAttr native call.
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
namedDecl->getValue()};
auto entryToDictionaryCall =
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
if (failed(entryToDictionaryCall))
return failure();

// Uses the new array for the next element.
dictAttrCall = entryToDictionaryCall;
} while (consumeIf(Token::comma));
if (failed(parseToken(Token::r_brace,
"expected `}` to close dictionary attribute")))
return failure();
return dictAttrCall;
}

FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
StringRef name = curToken.getSpelling();
SMRange nameLoc = curToken.getLoc();
Expand Down Expand Up @@ -2769,6 +2876,35 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
}

FailureOr<ast::DeclRefExpr *>
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
MutableArrayRef<ast::Expr *> arguments) {

FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
if (failed(nativeFuncExpr))
return failure();

if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
return emitError(nativeFuncName + " should be defined as a rewriter.");

FailureOr<ast::CallExpr *> nativeCall =
createCallExpr(loc, *nativeFuncExpr, arguments);
if (failed(nativeCall))
return failure();

// Create a unique anonymous name declaration to use, as its name is not
// important.
std::string anonName =
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
.str();
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
if (failed(varDecl))
return failure();

return createDeclRefExpr(loc, *varDecl);
}

FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
Expand Down
118 changes: 118 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,121 @@ Pattern RangeExpr {
// CHECK: %[[TYPE:.*]] = type : i32
// CHECK: operation({{.*}}) -> (%[[TYPE]] : !pdl.type)
Pattern TypeExpr => erase op<> -> (type<"i32">);

// -----

//===----------------------------------------------------------------------===//
// Parse attributes and rewrite
//===----------------------------------------------------------------------===//

// Rewriter helpers declarations.
Rewrite createDictionaryAttr() -> Attr;
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;
Rewrite createArrayAttr() -> Attr;
Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;

// CHECK-LABEL: pdl.pattern @RewriteOneEntryDictionary
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: %[[VAL_2:.*]] = attribute = "test"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr"
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]
// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_6]]
Pattern RewriteOneEntryDictionary {
let root = op<test.op> -> ();
let attr1 = attr<"\"test\"">;
rewrite root with {
let newRoot = op<test.success>() { some_dictionary = {firstAttr=attr1} } -> ();
replace root with newRoot;
};
}

// -----

// Rewriter helpers declarations.
Rewrite createDictionaryAttr() -> Attr;
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;

// CHECK-LABEL: pdl.pattern @RewriteMultipleEntriesDictionary
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
// CHECK: %[[VAL_3:.*]] = attribute = "test3"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr"
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = attribute = "secondAttr"
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr"
// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]
// CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_12]]
Pattern RewriteMultipleEntriesDictionary {
let root = op<test.op> -> ();
let attr2 = attr<"\"test2\"">;
let attr3 = attr<"\"test3\"">;
rewrite root with {
let newRoot = op<test.success>() { some_dictionary = {"firstAttr" = attr<"\"test1\"">, secondAttr = attr2, thirdAttr = attr3} } -> ();
replace root with newRoot;
};
}

// -----

// Rewriter helpers declarations.
Rewrite createDictionaryAttr() -> Attr;
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;
Rewrite createArrayAttr() -> Attr;
Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;

// CHECK-LABEL: pdl.pattern @RewriteOneDictionaryArrayAttr
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "createArrayAttr"
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr"
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_5:.*]] = attribute = "test1"
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_8]]
Pattern RewriteOneDictionaryArrayAttr {
let root = op<test.op> -> ();
rewrite root with {
let newRoot = op<test.success>() { some_array = [ {"firstAttr" = attr<"\"test1\"">}]} -> ();
replace root with newRoot;
};
}

// -----

// Rewriter helpers declarations.
Rewrite createDictionaryAttr() -> Attr;
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;
Rewrite createArrayAttr() -> Attr;
Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;

// CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr
// CHECK: %[[VAL_1:.*]] = operation "test.op"
// CHECK: %[[VAL_2:.*]] = attribute = "test2"
// CHECK: rewrite %[[VAL_1]] {
// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createArrayAttr"
// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr"
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]]
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_10]]
Pattern RewriteMultiplyElementsArrayAttr {
let root = op<test.op> -> ();
let attr2 = attr<"\"test2\"">;
rewrite root with {
let newRoot = op<test.success>() { some_array = [ {"firstAttr" = attr<"\"test1\"">}, attr2]} -> ();
replace root with newRoot;
};
}
37 changes: 37 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,43 @@ Pattern {

// -----

Pattern {
let root = op<func.func> -> ();
// CHECK: expected expression. If you are trying to create an ArrayAttr, use a space between `[` and `{`.
rewrite root with {
let newRoot = op<func.func>() { some_array = [{"firstAttr" = attr<"\"test\"">}]} -> ();
replace root with newRoot;
};;
}


// -----

Pattern {
let root = op<func.func> -> ();
// CHECK: expected '}]' in string literal
rewrite root with {
let newRoot = op<func.func>() { some_array = [{"firstAttr" = attr<"\"test\"">}, attr<"\"test\"">] } -> ();
replace root with newRoot;
};
}

// -----

Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;

Pattern {
let root = op<test.op> -> ();
let attr = attr<"\"test\"">;
rewrite root with {
// CHECK: undefined reference to `createArrayAttr`
let newRoot = op<test.success>() { some_array = [ attr<"\"test\""> ]} -> ();
replace root with newRoot;
};
}

// -----

//===----------------------------------------------------------------------===//
// `op` Expr
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,33 @@ Pattern {

// -----

// CHECK: |-NamedAttributeDecl {{.*}} Name<some_array>
// CHECK: `-UserRewriteDecl {{.*}} Name<addElemToArrayAttr> ResultType<Attr>
// CHECK: `Arguments`
// CHECK: `-CallExpr {{.*}} Type<Attr>
// CHECK: `-UserRewriteDecl {{.*}} Name<createArrayAttr> ResultType<Attr>
// CHECK: `-CallExpr {{.*}} Type<Attr>
// CHECK: `-UserRewriteDecl {{.*}} Name<addEntryToDictionaryAttr> ResultType<Attr>
// CHECK: `Arguments`
// CHECK: `-CallExpr {{.*}} Type<Attr>
// CHECK: `-UserRewriteDecl {{.*}} Name<createDictionaryAttr> ResultType<Attr>
// CHECK: `-AttributeExpr {{.*}} Value<""firstAttr"">
Rewrite createDictionaryAttr() -> Attr;
Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr;
Rewrite createArrayAttr() -> Attr;
Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr;

Pattern {
let root = op<test.op> -> ();
let attr = attr<"\"test\"">;
rewrite root with {
let newRoot = op<test.success>() { some_array = [ {"firstAttr" = attr<"\"test\"">}], attr} -> ();
replace root with newRoot;
};
}

// -----

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//
Expand Down