Skip to content

Commit 1265071

Browse files
FXML.2007: PDLL support for creating new ops with empty regions (#30)
1 parent 7cc5626 commit 1265071

File tree

17 files changed

+261
-52
lines changed

17 files changed

+261
-52
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
346346
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
347347
Variadic<PDL_Attribute>:$attributeValues,
348348
StrArrayAttr:$attributeValueNames,
349-
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
349+
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues,
350+
OptionalAttr<UI32Attr>:$numRegions);
350351
let results = (outs PDL_Operation:$op);
351352
let assemblyFormat = [{
352353
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
@@ -361,9 +362,10 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
361362
CArg<"ValueRange", "llvm::None">:$attrValues,
362363
CArg<"ValueRange", "llvm::None">:$resultTypes), [{
363364
auto nameAttr = name ? $_builder.getStringAttr(*name) : StringAttr();
365+
IntegerAttr numRegionsAttr;
364366
build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr,
365367
operandValues, attrValues, $_builder.getStrArrayAttr(attrNames),
366-
resultTypes);
368+
resultTypes, numRegionsAttr);
367369
}]>,
368370
];
369371
let extraClassDeclaration = [{

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,24 @@ def PDLInterp_CreateOperationOp
430430
Variadic<PDL_Attribute>:$inputAttributes,
431431
StrArrayAttr:$inputAttributeNames,
432432
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes,
433-
UnitAttr:$inferredResultTypes);
433+
UnitAttr:$inferredResultTypes,
434+
OptionalAttr<UI32Attr>:$numRegions);
434435
let results = (outs PDL_Operation:$resultOp);
435436

436437
let builders = [
437438
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
438439
"bool":$inferredResultTypes, "ValueRange":$operands,
439440
"ValueRange":$attributes, "ArrayAttr":$attributeNames), [{
441+
IntegerAttr numRegionsAttr;
440442
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
441-
operands, attributes, attributeNames, types, inferredResultTypes);
443+
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
444+
}]>,
445+
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
446+
"bool":$inferredResultTypes, "ValueRange":$operands,
447+
"ValueRange":$attributes, "ArrayAttr":$attributeNames, "uint32_t":$numRegions), [{
448+
auto numRegionsAttr = $_builder.getUI32IntegerAttr(numRegions);
449+
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
450+
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
442451
}]>
443452
];
444453
let assemblyFormat = [{

mlir/include/mlir/Tools/PDLL/AST/Nodes.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -501,12 +501,11 @@ class OperationExpr final
501501
private llvm::TrailingObjects<OperationExpr, Expr *,
502502
NamedAttributeDecl *> {
503503
public:
504-
static OperationExpr *create(Context &ctx, SMRange loc,
505-
const ods::Operation *odsOp,
506-
const OpNameDecl *nameDecl,
507-
ArrayRef<Expr *> operands,
508-
ArrayRef<Expr *> resultTypes,
509-
ArrayRef<NamedAttributeDecl *> attributes);
504+
static OperationExpr *
505+
create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
506+
const OpNameDecl *nameDecl, ArrayRef<Expr *> operands,
507+
ArrayRef<Expr *> resultTypes,
508+
ArrayRef<NamedAttributeDecl *> attributes, unsigned numRegions);
510509

511510
/// Return the name of the operation, or None if there isn't one.
512511
Optional<StringRef> getName() const;
@@ -542,19 +541,22 @@ class OperationExpr final
542541
return const_cast<OperationExpr *>(this)->getAttributes();
543542
}
544543

544+
unsigned getNumRegions() const { return numRegions; }
545+
545546
private:
546547
OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
547548
unsigned numOperands, unsigned numResultTypes,
548-
unsigned numAttributes, SMRange nameLoc)
549+
unsigned numAttributes, unsigned numRegions, SMRange nameLoc)
549550
: Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
550551
numResultTypes(numResultTypes), numAttributes(numAttributes),
551-
nameLoc(nameLoc) {}
552+
numRegions(numRegions), nameLoc(nameLoc) {}
552553

553554
/// The name decl of this expression.
554555
const OpNameDecl *nameDecl;
555556

556-
/// The number of operands, result types, and attributes of the operation.
557-
unsigned numOperands, numResultTypes, numAttributes;
557+
/// The number of operands, result types, attributes and regions of the
558+
/// operation.
559+
unsigned numOperands, numResultTypes, numAttributes, numRegions;
558560

559561
/// The location of the operation name in the expression if it has a name.
560562
SMRange nameLoc;

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,10 @@ void PatternLowering::generateRewriter(
767767

768768
// Create the new operation.
769769
Location loc = operationOp.getLoc();
770+
auto numRegions = operationOp.getNumRegions().value_or(0);
770771
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
771772
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
772-
attributes, operationOp.getAttributeValueNames());
773+
attributes, operationOp.getAttributeValueNames(), numRegions);
773774
rewriteValues[operationOp.getOp()] = createdOp;
774775

775776
// Generate accesses for any results that have their types constrained.

mlir/lib/Dialect/PDL/IR/PDL.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,34 +141,51 @@ LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
141141
// pdl::OperationOp
142142
//===----------------------------------------------------------------------===//
143143

144+
/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
145+
/// Also allows empty `{}`
144146
static ParseResult parseOperationOpAttributes(
145147
OpAsmParser &p,
146148
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
147149
ArrayAttr &attrNamesAttr) {
148150
Builder &builder = p.getBuilder();
149151
SmallVector<Attribute, 4> attrNames;
150152
if (succeeded(p.parseOptionalLBrace())) {
151-
auto parseOperands = [&]() {
152-
StringAttr nameAttr;
153-
OpAsmParser::UnresolvedOperand operand;
154-
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
155-
p.parseOperand(operand))
153+
if (failed(p.parseOptionalRBrace())) {
154+
auto parseOperands = [&]() {
155+
StringAttr nameAttr;
156+
OpAsmParser::UnresolvedOperand operand;
157+
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
158+
p.parseOperand(operand))
159+
return failure();
160+
attrNames.push_back(nameAttr);
161+
attrOperands.push_back(operand);
162+
return success();
163+
};
164+
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
156165
return failure();
157-
attrNames.push_back(nameAttr);
158-
attrOperands.push_back(operand);
159-
return success();
160-
};
161-
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
162-
return failure();
166+
}
163167
}
164168
attrNamesAttr = builder.getArrayAttr(attrNames);
165169
return success();
166170
}
167171

172+
/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
173+
/// Prints empty `{}` when it would not be possible to discern the attr-dict
174+
/// otherwise.
168175
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
169176
OperandRange attrArgs,
170177
ArrayAttr attrNames) {
171-
if (attrNames.empty())
178+
/// Only omit printing empty `{}` if there are no other attributes that have
179+
/// to be printed later because otherwise we could not discern the attr dict.
180+
static const SmallVector<StringRef, 3> specialAttrs = {
181+
"operand_segment_sizes", "attributeValueNames", "opName"};
182+
bool onlySpecialAttrs =
183+
llvm::all_of(op->getAttrs(), [&](const NamedAttribute &attr) {
184+
return llvm::any_of(specialAttrs, [&](const StringRef &predefinedAttr) {
185+
return attr.getName() == predefinedAttr;
186+
});
187+
});
188+
if (attrNames.empty() && onlySpecialAttrs)
172189
return;
173190
p << " {";
174191
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,

mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,35 +64,45 @@ LogicalResult CreateOperationOp::verify() {
6464
return success();
6565
}
6666

67+
/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
68+
/// Also allows empty `{}`
6769
static ParseResult parseCreateOperationOpAttributes(
6870
OpAsmParser &p,
6971
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
7072
ArrayAttr &attrNamesAttr) {
7173
Builder &builder = p.getBuilder();
7274
SmallVector<Attribute, 4> attrNames;
7375
if (succeeded(p.parseOptionalLBrace())) {
74-
auto parseOperands = [&]() {
75-
StringAttr nameAttr;
76-
OpAsmParser::UnresolvedOperand operand;
77-
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78-
p.parseOperand(operand))
76+
if (failed(p.parseOptionalRBrace())) {
77+
auto parseOperands = [&]() {
78+
StringAttr nameAttr;
79+
OpAsmParser::UnresolvedOperand operand;
80+
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
81+
p.parseOperand(operand))
82+
return failure();
83+
attrNames.push_back(nameAttr);
84+
attrOperands.push_back(operand);
85+
return success();
86+
};
87+
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
7988
return failure();
80-
attrNames.push_back(nameAttr);
81-
attrOperands.push_back(operand);
82-
return success();
83-
};
84-
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
85-
return failure();
89+
}
8690
}
8791
attrNamesAttr = builder.getArrayAttr(attrNames);
8892
return success();
8993
}
9094

95+
/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
96+
/// Prints empty `{}` when it would not be possible to discern the attr-dict
97+
/// otherwise.
9198
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
9299
CreateOperationOp op,
93100
OperandRange attrArgs,
94101
ArrayAttr attrNames) {
95-
if (attrNames.empty())
102+
/// Only omit printing empty `{}` if we have result types because otherwise we
103+
/// could not discern the attr dict.
104+
unsigned numResultTypes = op.getODSOperandIndexAndLength(2).second;
105+
if (attrNames.empty() && (numResultTypes > 0 || op.getInferredResultTypes()))
96106
return;
97107
p << " {";
98108
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,

mlir/lib/Rewrite/ByteCode.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,14 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
885885
writer.append(kInferTypesMarker);
886886
else
887887
writer.appendPDLValueList(op.getInputResultTypes());
888+
889+
// Add number of regions
890+
if (IntegerAttr attr = op.getNumRegionsAttr()) {
891+
writer.append(ByteCodeField(attr.getUInt()));
892+
} else {
893+
unsigned numRegions = 0;
894+
writer.append(ByteCodeField(numRegions));
895+
}
888896
}
889897
void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
890898
// Append the correct opcode for the range type.
@@ -1663,6 +1671,12 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
16631671
}
16641672
}
16651673

1674+
// handle regions:
1675+
unsigned numRegions = read();
1676+
for (unsigned i = 0; i < numRegions; i++) {
1677+
state.addRegion();
1678+
}
1679+
16661680
Operation *resultOp = rewriter.create(state);
16671681
memory[memIndex] = resultOp;
16681682

mlir/lib/Tools/PDLL/AST/NodePrinter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ void NodePrinter::printImpl(const MemberAccessExpr *expr) {
247247
void NodePrinter::printImpl(const OperationExpr *expr) {
248248
os << "OperationExpr " << expr << " Type<";
249249
print(expr->getType());
250-
os << ">\n";
250+
os << "> numRegions:" << expr->getNumRegions() << "\n";
251251

252252
printChildren(expr->getNameDecl());
253253
printChildren("Operands", expr->getOperands());

mlir/lib/Tools/PDLL/AST/Nodes.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,13 @@ MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
301301
// OperationExpr
302302
//===----------------------------------------------------------------------===//
303303

304-
OperationExpr *
305-
OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
306-
const OpNameDecl *name, ArrayRef<Expr *> operands,
307-
ArrayRef<Expr *> resultTypes,
308-
ArrayRef<NamedAttributeDecl *> attributes) {
304+
OperationExpr *OperationExpr::create(Context &ctx, SMRange loc,
305+
const ods::Operation *odsOp,
306+
const OpNameDecl *name,
307+
ArrayRef<Expr *> operands,
308+
ArrayRef<Expr *> resultTypes,
309+
ArrayRef<NamedAttributeDecl *> attributes,
310+
unsigned numRegions) {
309311
unsigned allocSize =
310312
OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
311313
operands.size() + resultTypes.size(), attributes.size());
@@ -315,7 +317,7 @@ OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
315317
Type resultType = OperationType::get(ctx, name->getName(), odsOp);
316318
OperationExpr *opExpr = new (rawData)
317319
OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
318-
attributes.size(), name->getLoc());
320+
attributes.size(), numRegions, name->getLoc());
319321
std::uninitialized_copy(operands.begin(), operands.end(),
320322
opExpr->getOperands().begin());
321323
std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),

mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,14 @@ Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
515515
for (const ast::Expr *result : expr->getResultTypes())
516516
results.push_back(genSingleExpr(result));
517517

518-
return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
519-
attrValues, results);
518+
auto operationOp = builder.create<pdl::OperationOp>(
519+
loc, opName, operands, attrNames, attrValues, results);
520+
521+
// numRegions
522+
if (expr->getNumRegions() > 0)
523+
operationOp.setNumRegions(expr->getNumRegions());
524+
525+
return operationOp;
520526
}
521527

522528
Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {

mlir/lib/Tools/PDLL/Parser/Parser.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ class Parser {
421421
OpResultTypeContext resultTypeContext,
422422
SmallVectorImpl<ast::Expr *> &operands,
423423
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
424-
SmallVectorImpl<ast::Expr *> &results);
424+
SmallVectorImpl<ast::Expr *> &results,
425+
unsigned numRegions);
425426
LogicalResult
426427
validateOperationOperands(SMRange loc, Optional<StringRef> name,
427428
const ods::Operation *odsOp,
@@ -2129,8 +2130,23 @@ Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) {
21292130
resultTypeContext = OpResultTypeContext::Interface;
21302131
}
21312132

2133+
// Parse list of regions
2134+
unsigned numRegions = 0;
2135+
if (consumeIf(Token::l_paren)) {
2136+
do {
2137+
if (failed(parseToken(Token::l_brace, "expected `{` to open region")))
2138+
return failure();
2139+
if (failed(parseToken(Token::r_brace, "expected `}` to close region")))
2140+
return failure();
2141+
numRegions++;
2142+
} while (consumeIf(Token::comma));
2143+
if (failed(parseToken(Token::r_paren, "expected `)` to close region "
2144+
"list")))
2145+
return failure();
2146+
}
2147+
21322148
return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands,
2133-
attributes, resultTypes);
2149+
attributes, resultTypes, numRegions);
21342150
}
21352151

21362152
FailureOr<ast::Expr *> Parser::parseTupleExpr() {
@@ -2807,7 +2823,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
28072823
OpResultTypeContext resultTypeContext,
28082824
SmallVectorImpl<ast::Expr *> &operands,
28092825
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
2810-
SmallVectorImpl<ast::Expr *> &results) {
2826+
SmallVectorImpl<ast::Expr *> &results, unsigned numRegions) {
28112827
Optional<StringRef> opNameRef = name->getName();
28122828
const ods::Operation *odsOp = lookupODSOperation(opNameRef);
28132829

@@ -2844,7 +2860,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
28442860
}
28452861

28462862
return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
2847-
attributes);
2863+
attributes, numRegions);
28482864
}
28492865

28502866
LogicalResult

mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,19 @@ module @range_op {
260260
}
261261
}
262262
}
263+
264+
// -----
265+
266+
// CHECK-LABEL: module @create_empty_region
267+
module @create_empty_region {
268+
// CHECK: module @rewriters
269+
// CHECK: func @pdl_generated_rewriter()
270+
// CHECK: %[[UNUSED:.*]] = pdl_interp.create_operation "bar.op" {} {numRegions = 1 : ui32}
271+
// CHECK: pdl_interp.finalize
272+
pdl.pattern : benefit(1) {
273+
%root = operation "foo.op"
274+
rewrite %root {
275+
%unused = operation "bar.op" {} {"numRegions" = 1 : ui32}
276+
}
277+
}
278+
}

0 commit comments

Comments
 (0)