@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
1446
1446
}
1447
1447
}
1448
1448
1449
- template < typename LinalgType>
1450
- Operation * createCollapsedOp (LinalgType op ,
1451
- const CollapsingInfo &collapsingInfo ,
1452
- RewriterBase &rewriter) {
1453
- static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value ,
1454
- " unsupported linalg op type to create " );
1449
+ void collapseOperandsAndResults (LinalgOp op,
1450
+ const CollapsingInfo &collapsingInfo ,
1451
+ RewriterBase &rewriter ,
1452
+ SmallVectorImpl<Value> &inputOperands,
1453
+ SmallVectorImpl<Value> &outputOperands ,
1454
+ SmallVectorImpl<Type> &resultTypes) {
1455
1455
Location loc = op->getLoc ();
1456
-
1457
- // Get the input operands.
1458
- SmallVector<Value> inputOperands =
1456
+ inputOperands =
1459
1457
llvm::map_to_vector (op.getDpsInputOperands (), [&](OpOperand *opOperand) {
1460
1458
return getCollapsedOpOperand (loc, op, opOperand, collapsingInfo,
1461
1459
rewriter);
1462
1460
});
1463
1461
1464
1462
// Get the output operands and result types.
1465
- SmallVector<Type> resultTypes;
1466
- SmallVector<Value> outputOperands;
1467
1463
resultTypes.reserve (op.getNumDpsInits ());
1468
1464
outputOperands.reserve (op.getNumDpsInits ());
1469
1465
for (OpOperand &output : op.getDpsInitsMutable ()) {
@@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
1475
1471
if (!op.hasPureBufferSemantics ())
1476
1472
resultTypes.push_back (newOutput.getType ());
1477
1473
}
1474
+ }
1478
1475
1479
- if (isa<linalg::CopyOp>(op)) {
1480
- return rewriter.create <linalg::CopyOp>(loc, inputOperands[0 ],
1481
- outputOperands[0 ]);
1482
- }
1476
+ // / Clone a `LinalgOp` to a collapsed version of same name
1477
+ template <typename OpTy>
1478
+ OpTy cloneToCollapsedOp (RewriterBase &rewriter, OpTy origOp,
1479
+ const CollapsingInfo &collapsingInfo) {
1480
+ return nullptr ;
1481
+ }
1483
1482
1484
- // Get the iterator types for the operand.
1485
- SmallVector<utils::IteratorType> iteratorTypes =
1486
- getCollapsedOpIteratorTypes (op.getIteratorTypesArray (), collapsingInfo);
1483
+ // / Collapse any `LinalgOp` that does not require any specialization such as
1484
+ // / indexing_maps, iterator_types, etc.
1485
+ template <>
1486
+ LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1487
+ const CollapsingInfo &collapsingInfo) {
1488
+ SmallVector<Value> inputOperands, outputOperands;
1489
+ SmallVector<Type> resultTypes;
1490
+ collapseOperandsAndResults (origOp, collapsingInfo, rewriter, inputOperands,
1491
+ outputOperands, resultTypes);
1492
+ return cast<LinalgOp>(clone (
1493
+ rewriter, origOp, resultTypes,
1494
+ llvm::to_vector (llvm::concat<Value>(inputOperands, outputOperands))));
1495
+ }
1487
1496
1488
- // Get the indexing maps.
1489
- auto indexingMaps =
1490
- llvm::map_to_vector (op.getIndexingMapsArray (), [&](AffineMap map) {
1497
+ // / Collapse a `GenericOp`
1498
+ template <>
1499
+ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1500
+ GenericOp origOp,
1501
+ const CollapsingInfo &collapsingInfo) {
1502
+ SmallVector<Value> inputOperands, outputOperands;
1503
+ SmallVector<Type> resultTypes;
1504
+ collapseOperandsAndResults (origOp, collapsingInfo, rewriter, inputOperands,
1505
+ outputOperands, resultTypes);
1506
+ SmallVector<AffineMap> indexingMaps (
1507
+ llvm::map_range (origOp.getIndexingMapsArray (), [&](AffineMap map) {
1491
1508
return getCollapsedOpIndexingMap (map, collapsingInfo);
1492
- });
1509
+ }));
1510
+
1511
+ SmallVector<utils::IteratorType> iteratorTypes (getCollapsedOpIteratorTypes (
1512
+ origOp.getIteratorTypesArray (), collapsingInfo));
1493
1513
1494
- Operation * collapsedOp = rewriter.create <linalg::GenericOp>(
1495
- loc , resultTypes, inputOperands, outputOperands, indexingMaps,
1514
+ GenericOp collapsedOp = rewriter.create <linalg::GenericOp>(
1515
+ origOp. getLoc () , resultTypes, inputOperands, outputOperands, indexingMaps,
1496
1516
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1497
- Block *origOpBlock = &op ->getRegion (0 ).front ();
1517
+ Block *origOpBlock = &origOp ->getRegion (0 ).front ();
1498
1518
Block *collapsedOpBlock = &collapsedOp->getRegion (0 ).front ();
1499
1519
rewriter.mergeBlocks (origOpBlock, collapsedOpBlock,
1500
1520
collapsedOpBlock->getArguments ());
1501
-
1502
1521
return collapsedOp;
1503
1522
}
1504
1523
1524
+ LinalgOp createCollapsedOp (LinalgOp op, const CollapsingInfo &collapsingInfo,
1525
+ RewriterBase &rewriter) {
1526
+ if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation ())) {
1527
+ return cloneToCollapsedOp (rewriter, genericOp, collapsingInfo);
1528
+ } else {
1529
+ return cloneToCollapsedOp (rewriter, op, collapsingInfo);
1530
+ }
1531
+ }
1532
+
1505
1533
// / Implementation of fusion with reshape operation by collapsing dimensions.
1506
- template <typename LinalgType>
1507
- FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims (
1508
- LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
1534
+ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims (
1535
+ LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1509
1536
RewriterBase &rewriter) {
1510
- static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1511
- " unsupported linalg op type to collapse" );
1512
-
1513
1537
// Bail on trivial no-op cases.
1514
1538
if (op.getNumLoops () <= 1 || foldedIterationDims.empty () ||
1515
1539
llvm::all_of (foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1538
1562
}
1539
1563
1540
1564
// Bail on non-canonical ranges.
1541
- SmallVector<Range> loopRanges =
1542
- cast<LinalgOp>(op.getOperation ()).createLoopRanges (rewriter, op.getLoc ());
1565
+ SmallVector<Range> loopRanges = op.createLoopRanges (rewriter, op.getLoc ());
1543
1566
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1544
1567
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1545
1568
return cast<IntegerAttr>(attr).getInt () == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1555
1578
op, " expected all loop ranges to have zero start and unit stride" );
1556
1579
}
1557
1580
1558
- LinalgType collapsedOp = cast<LinalgType>(
1559
- createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
1581
+ LinalgOp collapsedOp = createCollapsedOp (op, collapsingInfo, rewriter);
1560
1582
1561
1583
Location loc = op->getLoc ();
1562
1584
if (collapsedOp.hasIndexSemantics ()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1597
1619
results.push_back (collapsedOpResult);
1598
1620
}
1599
1621
}
1600
- return results;
1622
+ return CollapseResult{ results, collapsedOp} ;
1601
1623
}
1602
1624
1603
1625
namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
1629
1651
continue ;
1630
1652
}
1631
1653
1632
- std::optional<SmallVector<Value>> replacements =
1633
- collapseOpIterationDims<linalg::GenericOp>(
1634
- genericOp, collapsableIterationDims, rewriter);
1635
- if (!replacements) {
1654
+ std::optional<CollapseResult> collapseResult = collapseOpIterationDims (
1655
+ genericOp, collapsableIterationDims, rewriter);
1656
+ if (!collapseResult) {
1636
1657
return rewriter.notifyMatchFailure (
1637
1658
genericOp, " failed to do the fusion by collapsing transformation" );
1638
1659
}
1639
1660
1640
- rewriter.replaceOp (genericOp, *replacements );
1661
+ rewriter.replaceOp (genericOp, collapseResult-> results );
1641
1662
return success ();
1642
1663
}
1643
1664
return failure ();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1671
1692
op, " specified dimensions cannot be collapsed" );
1672
1693
}
1673
1694
1674
- std::optional<SmallVector<Value>> replacements =
1675
- collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
1676
- rewriter);
1677
- if (!replacements) {
1695
+ std::optional<CollapseResult> collapseResult =
1696
+ collapseOpIterationDims (op, collapsableIterationDims, rewriter);
1697
+ if (!collapseResult) {
1678
1698
return rewriter.notifyMatchFailure (op, " failed to collapse dimensions" );
1679
1699
}
1680
- rewriter.replaceOp (op, *replacements );
1700
+ rewriter.replaceOp (op, collapseResult-> results );
1681
1701
return success ();
1682
1702
}
1683
1703
0 commit comments