@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
1373
1373
}
1374
1374
1375
1375
// / Get the new value to use for a given `OpOperand` in the collapsed operation.
1376
- static Value getCollapsedOpOperand (Location loc, GenericOp genericOp ,
1376
+ static Value getCollapsedOpOperand (Location loc, LinalgOp op ,
1377
1377
OpOperand *opOperand,
1378
1378
const CollapsingInfo &collapsingInfo,
1379
1379
OpBuilder &builder) {
1380
- AffineMap indexingMap = genericOp .getMatchingIndexingMap (opOperand);
1380
+ AffineMap indexingMap = op .getMatchingIndexingMap (opOperand);
1381
1381
SmallVector<ReassociationIndices> operandReassociation =
1382
1382
getOperandReassociation (indexingMap, collapsingInfo);
1383
1383
1384
- // If the number of entries in the reassocation for the operand is same as the
1385
- // number of results of the indexing map, then nothing to do for this operand.
1384
+ // If the number of entries in the reassociation for the operand is same as
1385
+ // the number of results of the indexing map, then nothing to do for this
1386
+ // operand.
1386
1387
Value operand = opOperand->get ();
1387
1388
if (operandReassociation.size () == indexingMap.getNumResults ())
1388
1389
return operand;
@@ -1439,41 +1440,100 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
1439
1440
}
1440
1441
}
1441
1442
1443
+ template <typename LinalgType>
1444
+ Operation *createCollapsedOp (LinalgType op,
1445
+ const CollapsingInfo &collapsingInfo,
1446
+ RewriterBase &rewriter) {
1447
+ static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1448
+ " unsupported linalg op type to create" );
1449
+ Location loc = op->getLoc ();
1450
+
1451
+ // Get the input operands.
1452
+ SmallVector<Value> inputOperands =
1453
+ llvm::map_to_vector (op.getDpsInputOperands (), [&](OpOperand *opOperand) {
1454
+ return getCollapsedOpOperand (loc, op, opOperand, collapsingInfo,
1455
+ rewriter);
1456
+ });
1457
+
1458
+ // Get the output operands and result types.
1459
+ SmallVector<Type> resultTypes;
1460
+ SmallVector<Value> outputOperands;
1461
+ resultTypes.reserve (op.getNumDpsInits ());
1462
+ outputOperands.reserve (op.getNumDpsInits ());
1463
+ for (OpOperand &output : op.getDpsInitsMutable ()) {
1464
+ Value newOutput =
1465
+ getCollapsedOpOperand (loc, op, &output, collapsingInfo, rewriter);
1466
+ outputOperands.push_back (newOutput);
1467
+ // If the op has "buffer semantics", then the init operands are ranked
1468
+ // memrefs and the op has no results.
1469
+ if (!op.hasBufferSemantics ())
1470
+ resultTypes.push_back (newOutput.getType ());
1471
+ }
1472
+
1473
+ if (isa<linalg::CopyOp>(op)) {
1474
+ return rewriter.create <linalg::CopyOp>(loc, inputOperands[0 ],
1475
+ outputOperands[0 ]);
1476
+ }
1477
+
1478
+ // Get the iterator types for the operand.
1479
+ SmallVector<utils::IteratorType> iteratorTypes =
1480
+ getCollapsedOpIteratorTypes (op.getIteratorTypesArray (), collapsingInfo);
1481
+
1482
+ // Get the indexing maps.
1483
+ auto indexingMaps =
1484
+ llvm::map_to_vector (op.getIndexingMapsArray (), [&](AffineMap map) {
1485
+ return getCollapsedOpIndexingMap (map, collapsingInfo);
1486
+ });
1487
+
1488
+ Operation *collapsedOp = rewriter.create <linalg::GenericOp>(
1489
+ loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1490
+ iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1491
+ Block *origOpBlock = &op->getRegion (0 ).front ();
1492
+ Block *collapsedOpBlock = &collapsedOp->getRegion (0 ).front ();
1493
+ rewriter.mergeBlocks (origOpBlock, collapsedOpBlock,
1494
+ collapsedOpBlock->getArguments ());
1495
+
1496
+ return collapsedOp;
1497
+ }
1498
+
1442
1499
// / Implementation of fusion with reshape operation by collapsing dimensions.
1443
- FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims (
1444
- GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1500
+ template <typename LinalgType>
1501
+ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims (
1502
+ LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
1445
1503
RewriterBase &rewriter) {
1504
+ static_assert (llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1505
+ " unsupported linalg op type to collapse" );
1506
+
1446
1507
// Bail on trivial no-op cases.
1447
- if (genericOp .getNumLoops () <= 1 || foldedIterationDims.empty () ||
1508
+ if (op .getNumLoops () <= 1 || foldedIterationDims.empty () ||
1448
1509
llvm::all_of (foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1449
1510
return foldedDims.size () <= 1 ;
1450
1511
}))
1451
1512
return failure ();
1452
1513
1453
- bool hasBufferSemantics = genericOp .hasBufferSemantics ();
1514
+ bool hasBufferSemantics = op .hasBufferSemantics ();
1454
1515
if (hasBufferSemantics &&
1455
- !llvm::all_of (genericOp ->getOperands (), [&](Value operand) -> bool {
1516
+ !llvm::all_of (op ->getOperands (), [&](Value operand) -> bool {
1456
1517
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType ());
1457
1518
if (!memRefToCollapse)
1458
1519
return true ;
1459
1520
1460
1521
return memref::CollapseShapeOp::isGuaranteedCollapsible (
1461
1522
memRefToCollapse, foldedIterationDims);
1462
1523
}))
1463
- return rewriter.notifyMatchFailure (genericOp ,
1524
+ return rewriter.notifyMatchFailure (op ,
1464
1525
" memref is not guaranteed collapsible" );
1465
1526
1466
1527
CollapsingInfo collapsingInfo;
1467
- if (failed (collapsingInfo. initialize (genericOp. getNumLoops (),
1468
- foldedIterationDims))) {
1528
+ if (failed (
1529
+ collapsingInfo. initialize (op. getNumLoops (), foldedIterationDims))) {
1469
1530
return rewriter.notifyMatchFailure (
1470
- genericOp , " illegal to collapse specified dimensions" );
1531
+ op , " illegal to collapse specified dimensions" );
1471
1532
}
1472
1533
1473
1534
// Bail on non-canonical ranges.
1474
1535
SmallVector<Range> loopRanges =
1475
- cast<LinalgOp>(genericOp.getOperation ())
1476
- .createLoopRanges (rewriter, genericOp.getLoc ());
1536
+ cast<LinalgOp>(op.getOperation ()).createLoopRanges (rewriter, op.getLoc ());
1477
1537
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1478
1538
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1479
1539
return cast<IntegerAttr>(attr).getInt () == value;
@@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
1486
1546
opFoldIsConstantValue (range.stride , 1 );
1487
1547
})) {
1488
1548
return rewriter.notifyMatchFailure (
1489
- genericOp,
1490
- " expected all loop ranges to have zero start and unit stride" );
1549
+ op, " expected all loop ranges to have zero start and unit stride" );
1491
1550
}
1492
1551
1493
- // Get the iterator types for the operand.
1494
- SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes (
1495
- genericOp.getIteratorTypesArray (), collapsingInfo);
1496
-
1497
- // Get the indexing maps.
1498
- auto indexingMaps = llvm::to_vector (
1499
- llvm::map_range (genericOp.getIndexingMapsArray (), [&](AffineMap map) {
1500
- return getCollapsedOpIndexingMap (map, collapsingInfo);
1501
- }));
1502
-
1503
- Location loc = genericOp->getLoc ();
1504
-
1505
- // Get the input operands.
1506
- auto inputOperands = llvm::to_vector (llvm::map_range (
1507
- genericOp.getDpsInputOperands (), [&](OpOperand *opOperand) {
1508
- return getCollapsedOpOperand (loc, genericOp, opOperand, collapsingInfo,
1509
- rewriter);
1510
- }));
1511
-
1512
- // Get the output operands and result types.
1513
- SmallVector<Type> resultTypes;
1514
- SmallVector<Value> outputOperands;
1515
- resultTypes.reserve (genericOp.getNumDpsInits ());
1516
- outputOperands.reserve (genericOp.getNumDpsInits ());
1517
- for (OpOperand &output : genericOp.getDpsInitsMutable ()) {
1518
- Value newOutput = getCollapsedOpOperand (loc, genericOp, &output,
1519
- collapsingInfo, rewriter);
1520
- outputOperands.push_back (newOutput);
1521
- // If the op has "buffer semantics", then the init operands are ranked
1522
- // memrefs and the op has no results.
1523
- if (!hasBufferSemantics)
1524
- resultTypes.push_back (newOutput.getType ());
1525
- }
1526
-
1527
- // Create the generic op.
1528
- auto collapsedGenericOp = rewriter.create <linalg::GenericOp>(
1529
- loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1530
- iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1531
- Block *origOpBlock = &genericOp->getRegion (0 ).front ();
1532
- Block *collapsedOpBlock = &collapsedGenericOp->getRegion (0 ).front ();
1533
- rewriter.mergeBlocks (origOpBlock, collapsedOpBlock,
1534
- collapsedOpBlock->getArguments ());
1552
+ LinalgType collapsedOp = cast<LinalgType>(
1553
+ createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
1535
1554
1536
- if (collapsedGenericOp.hasIndexSemantics ()) {
1555
+ Location loc = op->getLoc ();
1556
+ if (collapsedOp.hasIndexSemantics ()) {
1537
1557
// Collect the loop range of the generic op.
1538
1558
OpBuilder::InsertionGuard g (rewriter);
1539
- rewriter.setInsertionPoint (collapsedGenericOp );
1559
+ rewriter.setInsertionPoint (collapsedOp );
1540
1560
SmallVector<Value> loopBound =
1541
- llvm::to_vector ( llvm::map_range (loopRanges, [&](Range range) {
1561
+ llvm::map_to_vector (loopRanges, [&](Range range) {
1542
1562
return getValueOrCreateConstantIndexOp (rewriter, loc, range.size );
1543
- }));
1544
- generateCollapsedIndexingRegion (loc,
1545
- &collapsedGenericOp->getRegion (0 ).front (),
1563
+ });
1564
+ generateCollapsedIndexingRegion (loc, &collapsedOp->getRegion (0 ).front (),
1546
1565
collapsingInfo, loopBound, rewriter);
1547
1566
}
1548
1567
1549
1568
// Insert expanding reshape for the result to get back the original result
1550
1569
// type.
1551
1570
SmallVector<Value> results;
1552
- for (const auto &originalResult : llvm::enumerate (genericOp->getResults ())) {
1553
- Value collapsedOpResult =
1554
- collapsedGenericOp->getResult (originalResult.index ());
1571
+ for (const auto &originalResult : llvm::enumerate (op->getResults ())) {
1572
+ Value collapsedOpResult = collapsedOp->getResult (originalResult.index ());
1555
1573
auto originalResultType =
1556
1574
cast<ShapedType>(originalResult.value ().getType ());
1557
1575
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType ());
1558
1576
if (collapsedOpResultType.getRank () != originalResultType.getRank ()) {
1559
1577
AffineMap indexingMap =
1560
- genericOp .getIndexingMapMatchingResult (originalResult.value ());
1578
+ op .getIndexingMapMatchingResult (originalResult.value ());
1561
1579
SmallVector<ReassociationIndices> reassociation =
1562
1580
getOperandReassociation (indexingMap, collapsingInfo);
1563
1581
if (isa<MemRefType>(collapsedOpResult.getType ())) {
@@ -1606,8 +1624,8 @@ class FoldWithProducerReshapeOpByCollapsing
1606
1624
}
1607
1625
1608
1626
std::optional<SmallVector<Value>> replacements =
1609
- collapseGenericOpIterationDims (genericOp, collapsableIterationDims,
1610
- rewriter);
1627
+ collapseOpIterationDims<linalg::GenericOp>(
1628
+ genericOp, collapsableIterationDims, rewriter);
1611
1629
if (!replacements) {
1612
1630
return rewriter.notifyMatchFailure (
1613
1631
genericOp, " failed to do the fusion by collapsing transformation" );
@@ -1624,36 +1642,36 @@ class FoldWithProducerReshapeOpByCollapsing
1624
1642
};
1625
1643
1626
1644
// / Pattern to collapse dimensions.
1627
- class CollapseLinalgDimensions : public OpRewritePattern <GenericOp> {
1645
+ template <typename LinalgType>
1646
+ class CollapseLinalgDimensions : public OpRewritePattern <LinalgType> {
1628
1647
public:
1629
1648
CollapseLinalgDimensions (MLIRContext *context,
1630
1649
GetCollapsableDimensionsFn collapseDimensions,
1631
1650
PatternBenefit benefit = 1 )
1632
- : OpRewritePattern<GenericOp >(context, benefit),
1651
+ : OpRewritePattern<LinalgType >(context, benefit),
1633
1652
controlCollapseDimension (std::move(collapseDimensions)) {}
1634
1653
1635
- LogicalResult matchAndRewrite (GenericOp genericOp ,
1654
+ LogicalResult matchAndRewrite (LinalgType op ,
1636
1655
PatternRewriter &rewriter) const override {
1637
1656
SmallVector<ReassociationIndices> collapsableIterationDims =
1638
- controlCollapseDimension (genericOp );
1657
+ controlCollapseDimension (op );
1639
1658
if (collapsableIterationDims.empty ())
1640
1659
return failure ();
1641
1660
1642
1661
// Check if the specified list of dimensions to collapse is a valid list.
1643
- if (!areDimSequencesPreserved (genericOp .getIndexingMapsArray (),
1662
+ if (!areDimSequencesPreserved (op .getIndexingMapsArray (),
1644
1663
collapsableIterationDims)) {
1645
1664
return rewriter.notifyMatchFailure (
1646
- genericOp , " specified dimensions cannot be collapsed" );
1665
+ op , " specified dimensions cannot be collapsed" );
1647
1666
}
1648
1667
1649
1668
std::optional<SmallVector<Value>> replacements =
1650
- collapseGenericOpIterationDims (genericOp , collapsableIterationDims,
1651
- rewriter);
1669
+ collapseOpIterationDims<LinalgType>(op , collapsableIterationDims,
1670
+ rewriter);
1652
1671
if (!replacements) {
1653
- return rewriter.notifyMatchFailure (genericOp,
1654
- " failed to collapse dimensions" );
1672
+ return rewriter.notifyMatchFailure (op, " failed to collapse dimensions" );
1655
1673
}
1656
- rewriter.replaceOp (genericOp , *replacements);
1674
+ rewriter.replaceOp (op , *replacements);
1657
1675
return success ();
1658
1676
}
1659
1677
@@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
1884
1902
void mlir::linalg::populateCollapseDimensions (
1885
1903
RewritePatternSet &patterns,
1886
1904
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1887
- patterns.add <CollapseLinalgDimensions>(patterns.getContext (),
1888
- controlCollapseDimensions);
1905
+ patterns.add <CollapseLinalgDimensions<linalg::GenericOp>,
1906
+ CollapseLinalgDimensions<linalg::CopyOp>>(
1907
+ patterns.getContext (), controlCollapseDimensions);
1889
1908
}
1890
1909
1891
1910
// ===---------------------------------------------------------------------===//
0 commit comments