Skip to content

Commit a309fe1

Browse files
committed
Pdl: Allow to define builtin native calls
1 parent abf4234 commit a309fe1

File tree

11 files changed

+279
-93
lines changed

11 files changed

+279
-93
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- Builtins.h - Builtin functions of the PDL dialect --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines builtin functions of the PDL dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
14+
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_
15+
16+
namespace mlir {
17+
class PDLPatternModule;
18+
class Attribute;
19+
class PatternRewriter;
20+
21+
namespace pdl {
22+
void registerBuiltins(PDLPatternModule &pdlPattern);
23+
24+
namespace builtin {
25+
Attribute createDictionaryAttr(PatternRewriter &rewriter);
26+
Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
27+
Attribute dictAttr, Attribute attrName,
28+
Attribute attrEntry);
29+
Attribute createArrayAttr(PatternRewriter &rewriter);
30+
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
31+
Attribute element);
32+
} // namespace builtin
33+
} // namespace pdl
34+
} // namespace mlir
35+
36+
#endif // MLIR_DIALECT_PDL_IR_BUILTINS_H_

mlir/lib/Dialect/PDL/IR/Builtins.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include <mlir/Dialect/PDL/IR/Builtins.h>
2+
#include <mlir/IR/PatternMatch.h>
3+
4+
using namespace mlir;
5+
6+
namespace mlir::pdl {
7+
namespace builtin {
8+
mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) {
9+
return rewriter.getDictionaryAttr({});
10+
}
11+
12+
mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter,
13+
mlir::Attribute dictAttr,
14+
mlir::Attribute attrName,
15+
mlir::Attribute attrEntry) {
16+
assert(isa<DictionaryAttr>(dictAttr));
17+
auto attr = dictAttr.cast<DictionaryAttr>();
18+
auto name = attrName.cast<StringAttr>();
19+
std::vector<NamedAttribute> values = attr.getValue().vec();
20+
21+
// Remove entry if it exists in the dictionary.
22+
llvm::erase_if(values, [&](NamedAttribute &namedAttr) {
23+
return namedAttr.getName() == name.getValue();
24+
});
25+
26+
values.push_back(rewriter.getNamedAttr(name, attrEntry));
27+
return rewriter.getDictionaryAttr(values);
28+
}
29+
30+
mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) {
31+
return rewriter.getArrayAttr({});
32+
}
33+
34+
mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
35+
mlir::Attribute attr,
36+
mlir::Attribute element) {
37+
assert(isa<ArrayAttr>(attr));
38+
auto values = cast<ArrayAttr>(attr).getValue().vec();
39+
values.push_back(element);
40+
return rewriter.getArrayAttr(values);
41+
}
42+
} // namespace builtin
43+
44+
void registerBuiltins(PDLPatternModule &pdlPattern) {
45+
using namespace builtin;
46+
// See Parser::defineBuiltins()
47+
pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr",
48+
createDictionaryAttr);
49+
pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr",
50+
addEntryToDictionaryAttr);
51+
pdlPattern.registerRewriteFunction("__builtin_createArrayAttr",
52+
createArrayAttr);
53+
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
54+
addElemToArrayAttr);
55+
}
56+
} // namespace mlir::pdl

mlir/lib/Dialect/PDL/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRPDLDialect
2+
Builtins.cpp
23
PDL.cpp
34
PDLTypes.cpp
45

mlir/lib/Rewrite/FrozenRewritePatternSet.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Interfaces/SideEffectInterfaces.h"
1414
#include "mlir/Pass/Pass.h"
1515
#include "mlir/Pass/PassManager.h"
16+
#include <mlir/Dialect/PDL/IR/Builtins.h>
1617
#include <optional>
1718

1819
using namespace mlir;
@@ -132,6 +133,8 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
132133
llvm::report_fatal_error(
133134
"failed to lower PDL pattern module to the PDL Interpreter");
134135

136+
pdl::registerBuiltins(pdlPatterns);
137+
135138
// Generate the pdl bytecode.
136139
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
137140
pdlModule, pdlPatterns.takeConfigs(), configMap,

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ class Parser {
106106
/// Pop the last decl scope from the lexer.
107107
void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); }
108108

109+
/// Creates a native constraint taking a set of Attr as arguments.
110+
/// The number of arguments and their names is given by argNames.
111+
/// The native returns an Attr when returnsAttr is true, otherwise returns
112+
/// nothing.
113+
template <class T>
114+
T *declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
115+
bool returnsAttr);
116+
117+
/// Register all builtin natives.
118+
void declareBuiltins();
119+
109120
/// Parse the body of an AST module.
110121
LogicalResult parseModuleBody(SmallVectorImpl<ast::Decl *> &decls);
111122

@@ -416,12 +427,12 @@ class Parser {
416427
FailureOr<ast::MemberAccessExpr *>
417428
createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
418429

419-
// Create a native call with \p nativeFuncName and \p arguments.
430+
// Create a native call with \p function and \p arguments.
420431
// This should be accompanied by a C++ implementation of the function that
421432
// needs to be linked and registered in passes that process PDLL files.
422-
FailureOr<ast::DeclRefExpr *>
423-
createNativeCall(SMRange loc, StringRef nativeFuncName,
424-
MutableArrayRef<ast::Expr *> arguments);
433+
FailureOr<ast::Expr *>
434+
createBuiltinCall(SMRange loc, ast::Decl *function,
435+
MutableArrayRef<ast::Expr *> arguments);
425436

426437
/// Validate the member access `name` into the given parent expression. On
427438
/// success, this also returns the type of the member accessed.
@@ -576,13 +587,64 @@ class Parser {
576587

577588
/// The optional code completion context.
578589
CodeCompleteContext *codeCompleteContext;
590+
591+
struct {
592+
ast::UserRewriteDecl *createDictionaryAttr;
593+
ast::UserRewriteDecl *addEntryToDictionaryAttr;
594+
ast::UserRewriteDecl *createArrayAttr;
595+
ast::UserRewriteDecl *addElemToArrayAttr;
596+
} builtins{};
579597
};
580598
} // namespace
581599

600+
template <class T>
601+
T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
602+
bool returnsAttr) {
603+
SMRange loc;
604+
auto attrConstr = ast::ConstraintRef(
605+
ast::AttrConstraintDecl::create(ctx, loc, nullptr), loc);
606+
607+
pushDeclScope();
608+
SmallVector<ast::VariableDecl *> args;
609+
for (auto argName : argNames) {
610+
FailureOr<ast::VariableDecl *> arg =
611+
createArgOrResultVariableDecl(argName, loc, attrConstr);
612+
assert(succeeded(arg));
613+
args.push_back(*arg);
614+
}
615+
SmallVector<ast::VariableDecl *> results;
616+
if (returnsAttr) {
617+
auto result = createArgOrResultVariableDecl("", loc, attrConstr);
618+
assert(succeeded(result));
619+
results.push_back(*result);
620+
}
621+
popDeclScope();
622+
623+
auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
624+
args, results, {}, attrTy);
625+
curDeclScope->add(constraintDecl);
626+
return constraintDecl;
627+
}
628+
629+
void Parser::declareBuiltins() {
630+
builtins.createDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
631+
"__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true);
632+
builtins.addEntryToDictionaryAttr = declareBuiltin<ast::UserRewriteDecl>(
633+
"__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"},
634+
/*returnsAttr=*/true);
635+
builtins.createArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
636+
"__builtin_createArrayAttr", {}, /*returnsAttr=*/true);
637+
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
638+
"__builtin_addElemToArrayAttr", {"attr", "element"},
639+
/*returnsAttr=*/true);
640+
}
641+
582642
FailureOr<ast::Module *> Parser::parseModule() {
583643
SMLoc moduleLoc = curToken.getStartLoc();
584644
pushDeclScope();
585645

646+
declareBuiltins();
647+
586648
// Parse the top-level decls of the module.
587649
SmallVector<ast::Decl *> decls;
588650
if (failed(parseModuleBody(decls)))
@@ -1869,7 +1931,7 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
18691931
"Parsing of array attributes as constraint not supported!");
18701932

18711933
auto arrayAttrCall =
1872-
createNativeCall(curToken.getLoc(), "createArrayAttr", {});
1934+
createBuiltinCall(curToken.getLoc(), builtins.createArrayAttr, {});
18731935
if (failed(arrayAttrCall))
18741936
return failure();
18751937

@@ -1879,8 +1941,8 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {
18791941
return failure();
18801942

18811943
SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttrCall, *attr};
1882-
auto elemToArrayCall = createNativeCall(
1883-
curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs);
1944+
auto elemToArrayCall = createBuiltinCall(
1945+
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);
18841946
if (failed(elemToArrayCall))
18851947
return failure();
18861948

@@ -1961,7 +2023,7 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
19612023
return emitError(
19622024
"Parsing of dictionary attributes as constraint not supported!");
19632025

1964-
auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {});
2026+
auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {});
19652027
if (failed(dictAttrCall))
19662028
return failure();
19672029

@@ -1995,8 +2057,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
19952057
// Create addEntryToDictionaryAttr native call.
19962058
SmallVector<ast::Expr *> arrayAttrArgs{*dictAttrCall, *stringAttrRef,
19972059
namedDecl->getValue()};
1998-
auto entryToDictionaryCall =
1999-
createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs);
2060+
auto entryToDictionaryCall = createBuiltinCall(
2061+
loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs);
20002062
if (failed(entryToDictionaryCall))
20012063
return failure();
20022064

@@ -2895,33 +2957,20 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
28952957
return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType);
28962958
}
28972959

2898-
FailureOr<ast::DeclRefExpr *>
2899-
Parser::createNativeCall(SMRange loc, StringRef nativeFuncName,
2900-
MutableArrayRef<ast::Expr *> arguments) {
2960+
FailureOr<ast::Expr *>
2961+
Parser::createBuiltinCall(SMRange loc, ast::Decl *function,
2962+
MutableArrayRef<ast::Expr *> arguments) {
29012963

2902-
FailureOr<ast::Expr *> nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc);
2964+
FailureOr<ast::Expr *> nativeFuncExpr = createDeclRefExpr(loc, function);
29032965
if (failed(nativeFuncExpr))
29042966
return failure();
29052967

2906-
if (!(*nativeFuncExpr)->getType().isa<ast::RewriteType>())
2907-
return emitError(nativeFuncName + " should be defined as a rewriter.");
2908-
29092968
FailureOr<ast::CallExpr *> nativeCall =
29102969
createCallExpr(loc, *nativeFuncExpr, arguments);
29112970
if (failed(nativeCall))
29122971
return failure();
29132972

2914-
// Create a unique anonymous name declaration to use, as its name is not
2915-
// important.
2916-
std::string anonName =
2917-
llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++)
2918-
.str();
2919-
FailureOr<ast::VariableDecl *> varDecl = defineVariableDecl(
2920-
anonName, loc, (*nativeCall)->getType(), *nativeCall, {});
2921-
if (failed(varDecl))
2922-
return failure();
2923-
2924-
return createDeclRefExpr(loc, *varDecl);
2973+
return *nativeCall;
29252974
}
29262975

29272976
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,

0 commit comments

Comments
 (0)