@@ -661,6 +661,26 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
661
661
// GenericOp
662
662
// ===----------------------------------------------------------------------===//
663
663
664
+ static void buildGenericRegion (
665
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
666
+ ValueRange outputs,
667
+ function_ref<void (OpBuilder &, Location, ValueRange)> bodyBuild) {
668
+ SmallVector<Type, 4 > blockArgTypes;
669
+ SmallVector<Location, 4 > blockArgLocs;
670
+ for (ValueRange container : {inputs, outputs}) {
671
+ for (Value v : container) {
672
+ blockArgTypes.push_back (getElementTypeOrSelf (v));
673
+ blockArgLocs.push_back (v.getLoc ());
674
+ }
675
+ }
676
+
677
+ OpBuilder::InsertionGuard guard (builder);
678
+ auto ®ion = *result.regions .front ();
679
+ Block *bodyBlock =
680
+ builder.createBlock (®ion, region.end (), blockArgTypes, blockArgLocs);
681
+ bodyBuild (builder, result.location , bodyBlock->getArguments ());
682
+ }
683
+
664
684
void GenericOp::getAsmBlockArgumentNames (Region ®ion,
665
685
OpAsmSetValueNameFn setNameFn) {
666
686
for (Value v : getRegionInputArgs ())
@@ -678,23 +698,8 @@ void GenericOp::build(
678
698
build (builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
679
699
iteratorTypes, doc, libraryCall);
680
700
result.addAttributes (attributes);
681
- if (!bodyBuild)
682
- return ;
683
-
684
- SmallVector<Type, 4 > blockArgTypes;
685
- SmallVector<Location, 4 > blockArgLocs;
686
- for (ValueRange container : {inputs, outputs}) {
687
- for (Value v : container) {
688
- blockArgTypes.push_back (getElementTypeOrSelf (v));
689
- blockArgLocs.push_back (v.getLoc ());
690
- }
691
- }
692
-
693
- OpBuilder::InsertionGuard guard (builder);
694
- auto ®ion = *result.regions .front ();
695
- Block *bodyBlock =
696
- builder.createBlock (®ion, region.end (), blockArgTypes, blockArgLocs);
697
- bodyBuild (builder, result.location , bodyBlock->getArguments ());
701
+ if (bodyBuild)
702
+ buildGenericRegion (builder, result, inputs, outputs, bodyBuild);
698
703
}
699
704
700
705
void GenericOp::build (
@@ -1329,6 +1334,22 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1329
1334
setNameFn (getResults ().front (), " mapped" );
1330
1335
}
1331
1336
1337
+ void MapOp::build (
1338
+ OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1339
+ function_ref<void (OpBuilder &, Location, ValueRange)> bodyBuild,
1340
+ ArrayRef<NamedAttribute> attributes) {
1341
+ build (builder, result, TypeRange{}, inputs, init);
1342
+ result.addAttributes (attributes);
1343
+
1344
+ // Add output types for `RankedTensorType` output arguments.
1345
+ Type initType = init.getType ();
1346
+ if (initType.isa <RankedTensorType>())
1347
+ result.addTypes (initType);
1348
+
1349
+ if (bodyBuild)
1350
+ buildGenericRegion (builder, result, inputs, /* outputs=*/ {}, bodyBuild);
1351
+ }
1352
+
1332
1353
ParseResult MapOp::parse (OpAsmParser &parser, OperationState &result) {
1333
1354
if (parseDstStyleOp (parser, result))
1334
1355
return failure ();
@@ -1436,6 +1457,25 @@ void ReduceOp::getAsmResultNames(
1436
1457
setNameFn (getResults ().front (), " reduced" );
1437
1458
}
1438
1459
1460
+ void ReduceOp::build (
1461
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
1462
+ ValueRange inits, ArrayRef<int64_t > dimensions,
1463
+ function_ref<void (OpBuilder &, Location, ValueRange)> bodyBuild,
1464
+ ArrayRef<NamedAttribute> attributes) {
1465
+ build (builder, result, TypeRange{}, inputs, inits, dimensions);
1466
+ result.addAttributes (attributes);
1467
+
1468
+ // Add output types for `RankedTensorType` output arguments.
1469
+ for (Value init : inits) {
1470
+ Type initType = init.getType ();
1471
+ if (initType.isa <RankedTensorType>())
1472
+ result.addTypes (initType);
1473
+ }
1474
+
1475
+ if (bodyBuild)
1476
+ buildGenericRegion (builder, result, inputs, inits, bodyBuild);
1477
+ }
1478
+
1439
1479
SmallVector<StringRef> ReduceOp::getIteratorTypesArray () {
1440
1480
int64_t inputRank = getInputs ()[0 ].getType ().cast <ShapedType>().getRank ();
1441
1481
SmallVector<StringRef> iteratorTypes (inputRank,
@@ -1618,45 +1658,32 @@ TransposeOp::getRegionBuilder() {
1618
1658
};
1619
1659
}
1620
1660
1621
- void TransposeOp::createRegion (::mlir::OpBuilder &opBuilder,
1622
- ::mlir::OperationState &odsState) {
1623
- Region *region = odsState.addRegion ();
1624
-
1625
- SmallVector<Type> argTypes;
1626
- SmallVector<Location> argLocs;
1627
- for (auto t : odsState.operands ) {
1628
- argTypes.push_back (getElementTypeOrSelf (t));
1629
- argLocs.push_back (opBuilder.getUnknownLoc ());
1630
- }
1631
-
1632
- // RAII.
1633
- OpBuilder::InsertionGuard guard (opBuilder);
1634
- Block *body =
1635
- opBuilder.createBlock (region, /* insertPt=*/ {}, argTypes, argLocs);
1636
-
1637
- ImplicitLocOpBuilder b (opBuilder.getUnknownLoc (), opBuilder);
1638
- getRegionBuilder ()(b, *body, odsState.attributes .getAttrs ());
1639
- }
1640
-
1641
- void TransposeOp::build (::mlir::OpBuilder &odsBuilder,
1642
- ::mlir::OperationState &odsState, Value input,
1643
- Value init, DenseI64ArrayAttr permutation,
1661
+ void TransposeOp::build (::mlir::OpBuilder &builder,
1662
+ ::mlir::OperationState &result, Value input, Value init,
1663
+ DenseI64ArrayAttr permutation,
1644
1664
ArrayRef<NamedAttribute> attributes) {
1645
- odsState.addOperands (input);
1646
- odsState.addOperands (init);
1647
- odsState.addAttribute (getPermutationAttrName (odsState.name ), permutation);
1648
- odsState.addAttributes (attributes);
1649
- odsState.addTypes (init.getType ());
1665
+ result.addOperands (input);
1666
+ result.addOperands (init);
1667
+ result.addAttribute (getPermutationAttrName (result.name ), permutation);
1668
+ result.addAttributes (attributes);
1669
+
1670
+ // Add output types for `RankedTensorType` output arguments.
1671
+ Type initType = init.getType ();
1672
+ if (initType.isa <RankedTensorType>())
1673
+ result.addTypes (initType);
1650
1674
1651
- createRegion (odsBuilder, odsState);
1675
+ buildGenericRegion (builder, result, input, init,
1676
+ [&](OpBuilder &b, Location loc, ValueRange args) {
1677
+ b.create <linalg::YieldOp>(loc, args[0 ]);
1678
+ });
1652
1679
}
1653
1680
1654
- void TransposeOp::build (::mlir::OpBuilder &odsBuilder ,
1655
- ::mlir::OperationState &odsState , Value input,
1656
- Value init, ArrayRef<int64_t > permutation,
1681
+ void TransposeOp::build (::mlir::OpBuilder &builder ,
1682
+ ::mlir::OperationState &result , Value input, Value init ,
1683
+ ArrayRef<int64_t > permutation,
1657
1684
ArrayRef<NamedAttribute> attributes) {
1658
- build (odsBuilder, odsState , input, init,
1659
- odsBuilder. getDenseI64ArrayAttr (permutation), attributes);
1685
+ build (builder, result , input, init, builder. getDenseI64ArrayAttr (permutation) ,
1686
+ attributes);
1660
1687
}
1661
1688
1662
1689
ParseResult TransposeOp::parse (OpAsmParser &parser, OperationState &result) {
@@ -1666,8 +1693,13 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1666
1693
})))
1667
1694
return failure ();
1668
1695
1669
- OpBuilder opBuilder (parser.getContext ());
1670
- createRegion (opBuilder, result);
1696
+ (void )result.addRegion ();
1697
+ OpBuilder builder (parser.getContext ());
1698
+ buildGenericRegion (builder, result, /* inputs=*/ result.operands ,
1699
+ /* outputs=*/ {},
1700
+ [&](OpBuilder &b, Location loc, ValueRange args) {
1701
+ b.create <linalg::YieldOp>(loc, args[0 ]);
1702
+ });
1671
1703
return success ();
1672
1704
}
1673
1705
0 commit comments