Skip to content

Commit e6b751e

Browse files
authored
Merge pull request #93 from Xilinx/matthias.pdl_builtin
PDLL: Allow to define builtin native calls
2 parents fc8ecac + 8ef9dcd commit e6b751e

File tree

11 files changed

+277
-91
lines changed

11 files changed

+277
-91
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- Builtins.h - Builtin functions of the PDL dialect --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines builtin functions of the PDL dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
14+
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_
15+
16+
namespace mlir {
17+
class PDLPatternModule;
18+
class Attribute;
19+
class PatternRewriter;
20+
21+
namespace pdl {
22+
void registerBuiltins(PDLPatternModule &pdlPattern);
23+
24+
namespace builtin {
25+
Attribute createDictionaryAttr(PatternRewriter &rewriter);
26+
Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
27+
Attribute dictAttr, Attribute attrName,
28+
Attribute attrEntry);
29+
Attribute createArrayAttr(PatternRewriter &rewriter);
30+
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
31+
Attribute element);
32+
} // namespace builtin
33+
} // namespace pdl
34+
} // namespace mlir
35+
36+
#endif // MLIR_DIALECT_PDL_IR_BUILTINS_H_

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include <mlir/Dialect/PDL/IR/Builtins.h>
2+
#include <mlir/IR/PatternMatch.h>
3+
4+
using namespace mlir;
5+
6+
namespace mlir::pdl {
7+
namespace builtin {
8+
mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) {
9+
return rewriter.getDictionaryAttr({});
10+
}
11+
12+
mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter,
13+
mlir::Attribute dictAttr,
14+
mlir::Attribute attrName,
15+
mlir::Attribute attrEntry) {
16+
assert(isa<DictionaryAttr>(dictAttr));
17+
auto attr = dictAttr.cast<DictionaryAttr>();
18+
auto name = attrName.cast<StringAttr>();
19+
std::vector<NamedAttribute> values = attr.getValue().vec();
20+
21+
// Remove entry if it exists in the dictionary.
22+
llvm::erase_if(values, [&](NamedAttribute &namedAttr) {
23+
return namedAttr.getName() == name.getValue();
24+
});
25+
26+
values.push_back(rewriter.getNamedAttr(name, attrEntry));
27+
return rewriter.getDictionaryAttr(values);
28+
}
29+
30+
mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) {
31+
return rewriter.getArrayAttr({});
32+
}
33+
34+
mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
35+
mlir::Attribute attr,
36+
mlir::Attribute element) {
37+
assert(isa<ArrayAttr>(attr));
38+
auto values = cast<ArrayAttr>(attr).getValue().vec();
39+
values.push_back(element);
40+
return rewriter.getArrayAttr(values);
41+
}
42+
} // namespace builtin
43+
44+
void registerBuiltins(PDLPatternModule &pdlPattern) {
45+
using namespace builtin;
46+
// See Parser::defineBuiltins()
47+
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr",
48+
createDictionaryAttr);
49+
pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr",
50+
addEntryToDictionaryAttr);
51+
pdlPattern.registerRewriteFunction("__builtin_createArrayAttr",
52+
createArrayAttr);
53+
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
54+
addElemToArrayAttr);
55+
}
56+
} // namespace mlir::pdl

mlir/lib/Dialect/PDL/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRPDLDialect
2+
Builtins.cpp
23
PDL.cpp
34
PDLTypes.cpp
45

mlir/lib/Rewrite/FrozenRewritePatternSet.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Interfaces/SideEffectInterfaces.h"
1414
#include "mlir/Pass/Pass.h"
1515
#include "mlir/Pass/PassManager.h"
16+
#include <mlir/Dialect/PDL/IR/Builtins.h>
1617
#include <optional>
1718

1819
using namespace mlir;
@@ -132,6 +133,8 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
132133
llvm::report_fatal_error(
133134
"failed to lower PDL pattern module to the PDL Interpreter");
134135

136+
pdl::registerBuiltins(pdlPatterns);
137+
135138
// Generate the pdl bytecode.
136139
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
137140
pdlModule, pdlPatterns.takeConfigs(), configMap,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ class Parser {
106106
/// Pop the last decl scope from the lexer.
107107
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108108

109+
/// Creates a native constraint taking a set of Attr as arguments.
110+
/// The number of arguments and their names is given by argNames.
111+
/// The native returns an Attr when returnsAttr is true, otherwise returns
112+
/// nothing.
113+
template <class T>
114+
T *declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
115+
bool returnsAttr);
116+
117+
/// Register all builtin natives.
118+
void declareBuiltins();
119+
109120
/// Parse the body of an AST module.
110121
LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111122

@@ -418,12 +429,12 @@ class Parser {
418429
FailureOr<ast::MemberAccessExpr *>
419430
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
420431

421-
// Create a native call with \p nativeFuncName and \p arguments.
432+
// Create a native call with \p function and \p arguments.
422433
// This should be accompanied by a C++ implementation of the function that
423434
// needs to be linked and registered in passes that process PDLL files.
424-
FailureOr<ast::DeclRefExpr *>
425-
createNativeCall(SMRange loc, StringRef nativeFuncName,
426-
MutableArrayRef<ast::Expr *> arguments);
435+
FailureOr<ast::Expr *>
436+
createBuiltinCall(SMRange loc, ast::Decl *function,
437+
MutableArrayRef<ast::Expr *> arguments);
427438

428439
/// Validate the member access `name` into the given parent expression. On
429440
/// success, this also returns the type of the member accessed.
@@ -578,13 +589,64 @@ class Parser {
578589

579590
/// The optional code completion context.
580591
CodeCompleteContext *codeCompleteContext;
592+
593+
struct {
594+
ast::UserRewriteDecl *createDictionaryAttr;
595+
ast::UserRewriteDecl *addEntryToDictionaryAttr;
596+
ast::UserRewriteDecl *createArrayAttr;
597+
ast::UserRewriteDecl *addElemToArrayAttr;
598+
} builtins{};
581599
};
582600
} // namespace
583601

602+
template <class T>
603+
T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
604+
bool returnsAttr) {
605+
SMRange loc;
606+
auto attrConstr = ast::ConstraintRef(
607+
ast::AttrConstraintDecl::create(ctx, loc, nullptr), loc);
608+
609+
pushDeclScope();
610+
SmallVector<ast::VariableDecl *> args;
611+
for (auto argName : argNames) {
612+
FailureOr<ast::VariableDecl *> arg =
613+
createArgOrResultVariableDecl(argName, loc, attrConstr);
614+
assert(succeeded(arg));
615+
args.push_back(*arg);
616+
}
617+
SmallVector<ast::VariableDecl *> results;
618+
if (returnsAttr) {
619+
auto result = createArgOrResultVariableDecl("", loc, attrConstr);
620+
assert(succeeded(result));
621+
results.push_back(*result);
622+
}
623+
popDeclScope();
624+
625+
auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
626+
args, results, {}, attrTy);
627+
curDeclScope->add(constraintDecl);
628+
return constraintDecl;
629+
}
630+
631+
void Parser::declareBuiltins() {
632+
builtins.createDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
633+
"__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true);
634+
builtins.addEntryToDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
635+
"__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"},
636+
/*returnsAttr=*/true);
637+
builtins.createArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
638+
"__builtin_createArrayAttr", {}, /*returnsAttr=*/true);
639+
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
640+
"__builtin_addElemToArrayAttr", {"attr", "element"},
641+
/*returnsAttr=*/true);
642+
}
643+
584644
FailureOr<ast::Module *> Parser::parseModule() {
585645
SMLoc moduleLoc = curToken.getStartLoc();
586646
pushDeclScope();
587647

648+
declareBuiltins();
649+
588650
// Parse the top-level decls of the module.
589651
SmallVector<ast::Decl *> decls;
590652
if (failed(parseModuleBody(decls)))
@@ -1874,7 +1936,7 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
18741936
"Parsing of array attributes as constraint not supported!");
18751937

18761938
auto arrayAttrCall =
1877-
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
1939+
createBuiltinCall(curToken.getLoc(), builtins.createArrayAttr, {});
18781940
if (failed(arrayAttrCall))
18791941
return failure();
18801942

@@ -1884,8 +1946,8 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
18841946
return failure();
18851947

18861948
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
1887-
auto elemToArrayCall = createNativeCall(
1888-
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
1949+
auto elemToArrayCall = createBuiltinCall(
1950+
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);
18891951
if (failed(elemToArrayCall))
18901952
return failure();
18911953

@@ -1966,7 +2028,7 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
19662028
return emitError(
19672029
"Parsing of dictionary attributes as constraint not supported!");
19682030

1969-
auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
2031+
auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {});
19702032
if (failed(dictAttrCall))
19712033
return failure();
19722034

@@ -2000,8 +2062,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
20002062
// Create addEntryToDictionaryAttr native call.
20012063
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
20022064
namedDecl->getValue()};
2003-
auto entryToDictionaryCall =
2004-
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
2065+
auto entryToDictionaryCall = createBuiltinCall(
2066+
loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs);
20052067
if (failed(entryToDictionaryCall))
20062068
return failure();
20072069

@@ -2923,33 +2985,20 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
29232985
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
29242986
}
29252987

2926-
FailureOr<ast::DeclRefExpr *>
2927-
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
2928-
MutableArrayRef<ast::Expr *> arguments) {
2988+
FailureOr<ast::Expr *>
2989+
Parser::createBuiltinCall(SMRange loc, ast::Decl *function,
2990+
MutableArrayRef<ast::Expr *> arguments) {
29292991

2930-
FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
2992+
FailureOr<ast::Expr *> nativeFuncExpr = createDeclRefExpr(loc, function);
29312993
if (failed(nativeFuncExpr))
29322994
return failure();
29332995

2934-
if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
2935-
return emitError(nativeFuncName + " should be defined as a rewriter.");
2936-
29372996
FailureOr<ast::CallExpr *> nativeCall =
29382997
createCallExpr(loc, *nativeFuncExpr, arguments);
29392998
if (failed(nativeCall))
29402999
return failure();
29413000

2942-
// Create a unique anonymous name declaration to use, as its name is not
2943-
// important.
2944-
std::string anonName =
2945-
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
2946-
.str();
2947-
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
2948-
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
2949-
if (failed(varDecl))
2950-
return failure();
2951-
2952-
return createDeclRefExpr(loc, *varDecl);
3001+
return *nativeCall;
29533002
}
29543003

29553004
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,

0 commit comments

Comments
 (0)