Skip to content

Commit 52b3d2c

Browse files
author
Ferdinand Lemaire
committed
merge conflicts
2 parents d5d25bd + f891e20 commit 52b3d2c

File tree

22 files changed

+528
-65
lines changed

22 files changed

+528
-65
lines changed

mlir/include/mlir/Dialect/PDL/IR/PDLOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
346346
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operandValues,
347347
Variadic<PDL_Attribute>:$attributeValues,
348348
StrArrayAttr:$attributeValueNames,
349-
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues);
349+
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$typeValues,
350+
OptionalAttr<UI32Attr>:$numRegions);
350351
let results = (outs PDL_Operation:$op);
351352
let assemblyFormat = [{
352353
($opName^)? (`(` $operandValues^ `:` type($operandValues) `)`)?
@@ -361,9 +362,10 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
361362
CArg<"ValueRange", "std::nullopt">:$attrValues,
362363
CArg<"ValueRange", "std::nullopt">:$resultTypes), [{
363364
auto nameAttr = name ? $_builder.getStringAttr(*name) : StringAttr();
365+
IntegerAttr numRegionsAttr;
364366
build($_builder, $_state, $_builder.getType<OperationType>(), nameAttr,
365367
operandValues, attrValues, $_builder.getStrArrayAttr(attrNames),
366-
resultTypes);
368+
resultTypes, numRegionsAttr);
367369
}]>,
368370
];
369371
let extraClassDeclaration = [{

mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,24 @@ def PDLInterp_CreateOperationOp
430430
Variadic<PDL_Attribute>:$inputAttributes,
431431
StrArrayAttr:$inputAttributeNames,
432432
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes,
433-
UnitAttr:$inferredResultTypes);
433+
UnitAttr:$inferredResultTypes,
434+
OptionalAttr<UI32Attr>:$numRegions);
434435
let results = (outs PDL_Operation:$resultOp);
435436

436437
let builders = [
437438
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
438439
"bool":$inferredResultTypes, "ValueRange":$operands,
439440
"ValueRange":$attributes, "ArrayAttr":$attributeNames), [{
441+
IntegerAttr numRegionsAttr;
440442
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
441-
operands, attributes, attributeNames, types, inferredResultTypes);
443+
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
444+
}]>,
445+
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
446+
"bool":$inferredResultTypes, "ValueRange":$operands,
447+
"ValueRange":$attributes, "ArrayAttr":$attributeNames, "uint32_t":$numRegions), [{
448+
auto numRegionsAttr = $_builder.getUI32IntegerAttr(numRegions);
449+
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
450+
operands, attributes, attributeNames, types, inferredResultTypes, numRegionsAttr);
442451
}]>
443452
];
444453
let assemblyFormat = [{

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,8 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
11131113
let results = (outs
11141114
Tosa_Tensor:$output
11151115
);
1116+
1117+
let hasFolder = 1;
11161118
}
11171119

11181120
//===----------------------------------------------------------------------===//
@@ -1466,6 +1468,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
14661468
/// Method used by InferTypeOpInterface.
14671469
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
14681470
}];
1471+
let hasFolder = 1;
14691472
}
14701473

14711474
//===----------------------------------------------------------------------===//

mlir/include/mlir/Tools/PDLL/AST/Nodes.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -502,12 +502,11 @@ class OperationExpr final
502502
private llvm::TrailingObjects<OperationExpr, Expr *,
503503
NamedAttributeDecl *> {
504504
public:
505-
static OperationExpr *create(Context &ctx, SMRange loc,
506-
const ods::Operation *odsOp,
507-
const OpNameDecl *nameDecl,
508-
ArrayRef<Expr *> operands,
509-
ArrayRef<Expr *> resultTypes,
510-
ArrayRef<NamedAttributeDecl *> attributes);
505+
static OperationExpr *
506+
create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
507+
const OpNameDecl *nameDecl, ArrayRef<Expr *> operands,
508+
ArrayRef<Expr *> resultTypes,
509+
ArrayRef<NamedAttributeDecl *> attributes, unsigned numRegions);
511510

512511
/// Return the name of the operation, or std::nullopt if there isn't one.
513512
std::optional<StringRef> getName() const;
@@ -543,19 +542,22 @@ class OperationExpr final
543542
return const_cast<OperationExpr *>(this)->getAttributes();
544543
}
545544

545+
unsigned getNumRegions() const { return numRegions; }
546+
546547
private:
547548
OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
548549
unsigned numOperands, unsigned numResultTypes,
549-
unsigned numAttributes, SMRange nameLoc)
550+
unsigned numAttributes, unsigned numRegions, SMRange nameLoc)
550551
: Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
551552
numResultTypes(numResultTypes), numAttributes(numAttributes),
552-
nameLoc(nameLoc) {}
553+
numRegions(numRegions), nameLoc(nameLoc) {}
553554

554555
/// The name decl of this expression.
555556
const OpNameDecl *nameDecl;
556557

557-
/// The number of operands, result types, and attributes of the operation.
558-
unsigned numOperands, numResultTypes, numAttributes;
558+
/// The number of operands, result types, attributes and regions of the
559+
/// operation.
560+
unsigned numOperands, numResultTypes, numAttributes, numRegions;
559561

560562
/// The location of the operation name in the expression if it has a name.
561563
SMRange nameLoc;

mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ struct PatternLowering {
151151

152152
/// A mapping between constraint questions that refer to values created by
153153
/// constraints and the temporary placeholder values created for them.
154-
DenseMap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
154+
std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value> substitutions;
155155
};
156156
} // namespace
157157

@@ -377,8 +377,9 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
377377
auto *constrResPos = cast<ConstraintPosition>(pos);
378378
Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
379379
loc, StringAttr::get(builder.getContext(), "placeholder"));
380-
substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] =
381-
placeholderValue;
380+
substitutions.insert(
381+
{{constrResPos->getQuestion(), constrResPos->getIndex()},
382+
placeholderValue});
382383
value = placeholderValue;
383384
break;
384385
}
@@ -474,11 +475,15 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
474475
std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
475476
cstQuestion, result.index()};
476477
// Check if there are substitutions to perform. If the result is never
477-
// used no substitutions will have been generated.
478-
if (substitutions.count(substitutionKey)) {
479-
substitutions[substitutionKey].replaceAllUsesWith(result.value());
480-
substitutions[substitutionKey].getDefiningOp()->erase();
481-
}
478+
// used or multiple calls to the same constraint have been merged,
479+
// no substitutions will have been generated for this specific op.
480+
auto range = substitutions.equal_range(substitutionKey);
481+
std::for_each(range.first, range.second, [&](const auto &elem) {
482+
Value placeholder = elem.second;
483+
placeholder.replaceAllUsesWith(result.value());
484+
placeholder.getDefiningOp()->erase();
485+
});
486+
substitutions.erase(substitutionKey);
482487
}
483488
break;
484489
}
@@ -767,9 +772,10 @@ void PatternLowering::generateRewriter(
767772

768773
// Create the new operation.
769774
Location loc = operationOp.getLoc();
775+
auto numRegions = operationOp.getNumRegions().value_or(0);
770776
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
771777
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
772-
attributes, operationOp.getAttributeValueNames());
778+
attributes, operationOp.getAttributeValueNames(), numRegions);
773779
rewriteValues[operationOp.getOp()] = createdOp;
774780

775781
// Generate accesses for any results that have their types constrained.

mlir/lib/Dialect/PDL/IR/PDL.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,34 +142,51 @@ LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
142142
// pdl::OperationOp
143143
//===----------------------------------------------------------------------===//
144144

145+
/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
146+
/// Also allows empty `{}`
145147
static ParseResult parseOperationOpAttributes(
146148
OpAsmParser &p,
147149
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
148150
ArrayAttr &attrNamesAttr) {
149151
Builder &builder = p.getBuilder();
150152
SmallVector<Attribute, 4> attrNames;
151153
if (succeeded(p.parseOptionalLBrace())) {
152-
auto parseOperands = [&]() {
153-
StringAttr nameAttr;
154-
OpAsmParser::UnresolvedOperand operand;
155-
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
156-
p.parseOperand(operand))
154+
if (failed(p.parseOptionalRBrace())) {
155+
auto parseOperands = [&]() {
156+
StringAttr nameAttr;
157+
OpAsmParser::UnresolvedOperand operand;
158+
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
159+
p.parseOperand(operand))
160+
return failure();
161+
attrNames.push_back(nameAttr);
162+
attrOperands.push_back(operand);
163+
return success();
164+
};
165+
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
157166
return failure();
158-
attrNames.push_back(nameAttr);
159-
attrOperands.push_back(operand);
160-
return success();
161-
};
162-
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
163-
return failure();
167+
}
164168
}
165169
attrNamesAttr = builder.getArrayAttr(attrNames);
166170
return success();
167171
}
168172

173+
/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
174+
/// Prints empty `{}` when it would not be possible to discern the attr-dict
175+
/// otherwise.
169176
static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
170177
OperandRange attrArgs,
171178
ArrayAttr attrNames) {
172-
if (attrNames.empty())
179+
/// Only omit printing empty `{}` if there are no other attributes that have
180+
/// to be printed later because otherwise we could not discern the attr dict.
181+
static const SmallVector<StringRef, 3> specialAttrs = {
182+
"operand_segment_sizes", "attributeValueNames", "opName"};
183+
bool onlySpecialAttrs =
184+
llvm::all_of(op->getAttrs(), [&](const NamedAttribute &attr) {
185+
return llvm::any_of(specialAttrs, [&](const StringRef &predefinedAttr) {
186+
return attr.getName() == predefinedAttr;
187+
});
188+
});
189+
if (attrNames.empty() && onlySpecialAttrs)
173190
return;
174191
p << " {";
175192
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,

mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,35 +64,45 @@ LogicalResult CreateOperationOp::verify() {
6464
return success();
6565
}
6666

67+
/// Handles parsing of OperationOpAttributes, e.g. {"attr" = %attribute}.
68+
/// Also allows empty `{}`
6769
static ParseResult parseCreateOperationOpAttributes(
6870
OpAsmParser &p,
6971
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
7072
ArrayAttr &attrNamesAttr) {
7173
Builder &builder = p.getBuilder();
7274
SmallVector<Attribute, 4> attrNames;
7375
if (succeeded(p.parseOptionalLBrace())) {
74-
auto parseOperands = [&]() {
75-
StringAttr nameAttr;
76-
OpAsmParser::UnresolvedOperand operand;
77-
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78-
p.parseOperand(operand))
76+
if (failed(p.parseOptionalRBrace())) {
77+
auto parseOperands = [&]() {
78+
StringAttr nameAttr;
79+
OpAsmParser::UnresolvedOperand operand;
80+
if (p.parseAttribute(nameAttr) || p.parseEqual() ||
81+
p.parseOperand(operand))
82+
return failure();
83+
attrNames.push_back(nameAttr);
84+
attrOperands.push_back(operand);
85+
return success();
86+
};
87+
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
7988
return failure();
80-
attrNames.push_back(nameAttr);
81-
attrOperands.push_back(operand);
82-
return success();
83-
};
84-
if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
85-
return failure();
89+
}
8690
}
8791
attrNamesAttr = builder.getArrayAttr(attrNames);
8892
return success();
8993
}
9094

95+
/// Handles printing of OperationOpAttributes, e.g. {"attr" = %attribute}.
96+
/// Prints empty `{}` when it would not be possible to discern the attr-dict
97+
/// otherwise.
9198
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
9299
CreateOperationOp op,
93100
OperandRange attrArgs,
94101
ArrayAttr attrNames) {
95-
if (attrNames.empty())
102+
/// Only omit printing empty `{}` if we have result types because otherwise we
103+
/// could not discern the attr dict.
104+
unsigned numResultTypes = op.getODSOperandIndexAndLength(2).second;
105+
if (attrNames.empty() && (numResultTypes > 0 || op.getInferredResultTypes()))
96106
return;
97107
p << " {";
98108
interleaveComma(llvm::seq<int>(0, attrNames.size()), p,

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
562562
if (!lhsAttr || !rhsAttr)
563563
return {};
564564

565+
if (lhsTy != rhsTy)
566+
return {};
567+
565568
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
566569
resultTy);
567570
}
@@ -663,6 +666,26 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
663666
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
664667
}
665668

669+
OpFoldResult ReciprocalOp::fold(ArrayRef<Attribute> operands) {
670+
auto constantAttr = dyn_cast_or_null<DenseElementsAttr>(operands[0]);
671+
auto lhsTy = dyn_cast<RankedTensorType>(getInput1().getType());
672+
673+
if (!lhsTy || !constantAttr) {
674+
return {};
675+
}
676+
677+
if (!constantAttr.isSplat()) {
678+
return {};
679+
}
680+
681+
auto floatVal = constantAttr.getSplatValue<llvm::APFloat>();
682+
683+
auto recipAttr = FloatAttr::get(lhsTy.getElementType(), 1.0);
684+
APFloat recip = recipAttr.getValue();
685+
recip.divide(floatVal, APFloat::rmNearestTiesToEven);
686+
return DenseElementsAttr::get(lhsTy, recip);
687+
}
688+
666689
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
667690
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
668691
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
@@ -680,6 +703,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
680703
if (!lhsAttr || !rhsAttr)
681704
return {};
682705

706+
if (lhsTy != rhsTy)
707+
return {};
708+
683709
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
684710
resultTy);
685711
}
@@ -998,3 +1024,37 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
9981024

9991025
return getInput1();
10001026
}
1027+
1028+
OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
1029+
// Fold consecutive concats on the same axis into a single op.
1030+
// Keep track of the operands so we are able to construct a new concat
1031+
// later. Conservatively assume that we double the number of operands when
1032+
// folding
1033+
SmallVector<Value, 8> concatOperands;
1034+
concatOperands.reserve(2 * getNumOperands());
1035+
1036+
// Find all operands that are foldable concats
1037+
bool canFold = false;
1038+
for (Value operand : getOperands()) {
1039+
concatOperands.emplace_back(operand);
1040+
1041+
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
1042+
if (!producer)
1043+
continue;
1044+
1045+
// Foldable if axes are the same
1046+
if (getAxis() != producer.getAxis())
1047+
continue;
1048+
1049+
// Replace the original operand with all incoming operands
1050+
canFold = true;
1051+
concatOperands.pop_back();
1052+
llvm::append_range(concatOperands, producer->getOperands());
1053+
}
1054+
1055+
if (!canFold)
1056+
return {};
1057+
1058+
getOperation()->setOperands(concatOperands);
1059+
return getResult();
1060+
}

0 commit comments

Comments
 (0)