Skip to content

Commit 5c3ed39

Browse files
authored
[mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (#68526)
1 parent 86bc486 commit 5c3ed39

File tree

4 files changed

+151
-93
lines changed

4 files changed

+151
-93
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
10471047
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
10481048
ArrayRef<ReassociationIndices> dimSequences);
10491049

1050-
/// Collapses dimensions of linalg.generic operation. A precondition to
1051-
/// calling this method is that for each list in `foldedIterationDim`, the
1050+
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
1051+
/// to calling this method is that for each list in `foldedIterationDim`, the
10521052
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
1053-
/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
1053+
/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
10541054
/// When valid, the method also collapses the operands of the op. Returns
1055-
/// replacement values of the results of the original `genericOp` by inserting
1055+
/// replacement values of the results of the original `linalgOp` by inserting
10561056
/// reshapes to get back values of compatible types.
1057-
FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
1058-
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1059-
RewriterBase &rewriter);
1057+
template <typename LinalgType>
1058+
FailureOr<SmallVector<Value>>
1059+
collapseOpIterationDims(LinalgType op,
1060+
ArrayRef<ReassociationIndices> foldedIterationDims,
1061+
RewriterBase &rewriter);
10601062

10611063
struct LowerPackResult {
10621064
tensor::PadOp padOp;
@@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
15151517
/// to return an array of `ReassociationIndices` representing dimensions that
15161518
/// should be merged.
15171519
using GetCollapsableDimensionsFn =
1518-
std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
1520+
std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
15191521

15201522
/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
15211523
/// tensor operands when needed and expand back the result tensors.

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 103 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
13731373
}
13741374

13751375
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
1376-
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
1376+
static Value getCollapsedOpOperand(Location loc, LinalgOp op,
13771377
OpOperand *opOperand,
13781378
const CollapsingInfo &collapsingInfo,
13791379
OpBuilder &builder) {
1380-
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
1380+
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
13811381
SmallVector<ReassociationIndices> operandReassociation =
13821382
getOperandReassociation(indexingMap, collapsingInfo);
13831383

1384-
// If the number of entries in the reassocation for the operand is same as the
1385-
// number of results of the indexing map, then nothing to do for this operand.
1384+
// If the number of entries in the reassociation for the operand is same as
1385+
// the number of results of the indexing map, then nothing to do for this
1386+
// operand.
13861387
Value operand = opOperand->get();
13871388
if (operandReassociation.size() == indexingMap.getNumResults())
13881389
return operand;
@@ -1439,41 +1440,100 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
14391440
}
14401441
}
14411442

1443+
template <typename LinalgType>
1444+
Operation *createCollapsedOp(LinalgType op,
1445+
const CollapsingInfo &collapsingInfo,
1446+
RewriterBase &rewriter) {
1447+
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1448+
"unsupported linalg op type to create");
1449+
Location loc = op->getLoc();
1450+
1451+
// Get the input operands.
1452+
SmallVector<Value> inputOperands =
1453+
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1454+
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1455+
rewriter);
1456+
});
1457+
1458+
// Get the output operands and result types.
1459+
SmallVector<Type> resultTypes;
1460+
SmallVector<Value> outputOperands;
1461+
resultTypes.reserve(op.getNumDpsInits());
1462+
outputOperands.reserve(op.getNumDpsInits());
1463+
for (OpOperand &output : op.getDpsInitsMutable()) {
1464+
Value newOutput =
1465+
getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1466+
outputOperands.push_back(newOutput);
1467+
// If the op has "buffer semantics", then the init operands are ranked
1468+
// memrefs and the op has no results.
1469+
if (!op.hasBufferSemantics())
1470+
resultTypes.push_back(newOutput.getType());
1471+
}
1472+
1473+
if (isa<linalg::CopyOp>(op)) {
1474+
return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
1475+
outputOperands[0]);
1476+
}
1477+
1478+
// Get the iterator types for the operand.
1479+
SmallVector<utils::IteratorType> iteratorTypes =
1480+
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
1481+
1482+
// Get the indexing maps.
1483+
auto indexingMaps =
1484+
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
1485+
return getCollapsedOpIndexingMap(map, collapsingInfo);
1486+
});
1487+
1488+
Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
1489+
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1490+
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1491+
Block *origOpBlock = &op->getRegion(0).front();
1492+
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1493+
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1494+
collapsedOpBlock->getArguments());
1495+
1496+
return collapsedOp;
1497+
}
1498+
14421499
/// Implementation of fusion with reshape operation by collapsing dimensions.
1443-
FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
1444-
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
1500+
template <typename LinalgType>
1501+
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
1502+
LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
14451503
RewriterBase &rewriter) {
1504+
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
1505+
"unsupported linalg op type to collapse");
1506+
14461507
// Bail on trivial no-op cases.
1447-
if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1508+
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
14481509
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
14491510
return foldedDims.size() <= 1;
14501511
}))
14511512
return failure();
14521513

1453-
bool hasBufferSemantics = genericOp.hasBufferSemantics();
1514+
bool hasBufferSemantics = op.hasBufferSemantics();
14541515
if (hasBufferSemantics &&
1455-
!llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
1516+
!llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
14561517
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
14571518
if (!memRefToCollapse)
14581519
return true;
14591520

14601521
return memref::CollapseShapeOp::isGuaranteedCollapsible(
14611522
memRefToCollapse, foldedIterationDims);
14621523
}))
1463-
return rewriter.notifyMatchFailure(genericOp,
1524+
return rewriter.notifyMatchFailure(op,
14641525
"memref is not guaranteed collapsible");
14651526

14661527
CollapsingInfo collapsingInfo;
1467-
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
1468-
foldedIterationDims))) {
1528+
if (failed(
1529+
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
14691530
return rewriter.notifyMatchFailure(
1470-
genericOp, "illegal to collapse specified dimensions");
1531+
op, "illegal to collapse specified dimensions");
14711532
}
14721533

14731534
// Bail on non-canonical ranges.
14741535
SmallVector<Range> loopRanges =
1475-
cast<LinalgOp>(genericOp.getOperation())
1476-
.createLoopRanges(rewriter, genericOp.getLoc());
1536+
cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
14771537
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
14781538
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
14791539
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
14861546
opFoldIsConstantValue(range.stride, 1);
14871547
})) {
14881548
return rewriter.notifyMatchFailure(
1489-
genericOp,
1490-
"expected all loop ranges to have zero start and unit stride");
1549+
op, "expected all loop ranges to have zero start and unit stride");
14911550
}
14921551

1493-
// Get the iterator types for the operand.
1494-
SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
1495-
genericOp.getIteratorTypesArray(), collapsingInfo);
1496-
1497-
// Get the indexing maps.
1498-
auto indexingMaps = llvm::to_vector(
1499-
llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
1500-
return getCollapsedOpIndexingMap(map, collapsingInfo);
1501-
}));
1502-
1503-
Location loc = genericOp->getLoc();
1504-
1505-
// Get the input operands.
1506-
auto inputOperands = llvm::to_vector(llvm::map_range(
1507-
genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
1508-
return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
1509-
rewriter);
1510-
}));
1511-
1512-
// Get the output operands and result types.
1513-
SmallVector<Type> resultTypes;
1514-
SmallVector<Value> outputOperands;
1515-
resultTypes.reserve(genericOp.getNumDpsInits());
1516-
outputOperands.reserve(genericOp.getNumDpsInits());
1517-
for (OpOperand &output : genericOp.getDpsInitsMutable()) {
1518-
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
1519-
collapsingInfo, rewriter);
1520-
outputOperands.push_back(newOutput);
1521-
// If the op has "buffer semantics", then the init operands are ranked
1522-
// memrefs and the op has no results.
1523-
if (!hasBufferSemantics)
1524-
resultTypes.push_back(newOutput.getType());
1525-
}
1526-
1527-
// Create the generic op.
1528-
auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
1529-
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
1530-
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1531-
Block *origOpBlock = &genericOp->getRegion(0).front();
1532-
Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
1533-
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1534-
collapsedOpBlock->getArguments());
1552+
LinalgType collapsedOp = cast<LinalgType>(
1553+
createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
15351554

1536-
if (collapsedGenericOp.hasIndexSemantics()) {
1555+
Location loc = op->getLoc();
1556+
if (collapsedOp.hasIndexSemantics()) {
15371557
// Collect the loop range of the generic op.
15381558
OpBuilder::InsertionGuard g(rewriter);
1539-
rewriter.setInsertionPoint(collapsedGenericOp);
1559+
rewriter.setInsertionPoint(collapsedOp);
15401560
SmallVector<Value> loopBound =
1541-
llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
1561+
llvm::map_to_vector(loopRanges, [&](Range range) {
15421562
return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1543-
}));
1544-
generateCollapsedIndexingRegion(loc,
1545-
&collapsedGenericOp->getRegion(0).front(),
1563+
});
1564+
generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
15461565
collapsingInfo, loopBound, rewriter);
15471566
}
15481567

15491568
// Insert expanding reshape for the result to get back the original result
15501569
// type.
15511570
SmallVector<Value> results;
1552-
for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
1553-
Value collapsedOpResult =
1554-
collapsedGenericOp->getResult(originalResult.index());
1571+
for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1572+
Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
15551573
auto originalResultType =
15561574
cast<ShapedType>(originalResult.value().getType());
15571575
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
15581576
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
15591577
AffineMap indexingMap =
1560-
genericOp.getIndexingMapMatchingResult(originalResult.value());
1578+
op.getIndexingMapMatchingResult(originalResult.value());
15611579
SmallVector<ReassociationIndices> reassociation =
15621580
getOperandReassociation(indexingMap, collapsingInfo);
15631581
if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1624,8 @@ class FoldWithProducerReshapeOpByCollapsing
16061624
}
16071625

16081626
std::optional<SmallVector<Value>> replacements =
1609-
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1610-
rewriter);
1627+
collapseOpIterationDims<linalg::GenericOp>(
1628+
genericOp, collapsableIterationDims, rewriter);
16111629
if (!replacements) {
16121630
return rewriter.notifyMatchFailure(
16131631
genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1642,36 @@ class FoldWithProducerReshapeOpByCollapsing
16241642
};
16251643

16261644
/// Pattern to collapse dimensions.
1627-
class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
1645+
template <typename LinalgType>
1646+
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
16281647
public:
16291648
CollapseLinalgDimensions(MLIRContext *context,
16301649
GetCollapsableDimensionsFn collapseDimensions,
16311650
PatternBenefit benefit = 1)
1632-
: OpRewritePattern<GenericOp>(context, benefit),
1651+
: OpRewritePattern<LinalgType>(context, benefit),
16331652
controlCollapseDimension(std::move(collapseDimensions)) {}
16341653

1635-
LogicalResult matchAndRewrite(GenericOp genericOp,
1654+
LogicalResult matchAndRewrite(LinalgType op,
16361655
PatternRewriter &rewriter) const override {
16371656
SmallVector<ReassociationIndices> collapsableIterationDims =
1638-
controlCollapseDimension(genericOp);
1657+
controlCollapseDimension(op);
16391658
if (collapsableIterationDims.empty())
16401659
return failure();
16411660

16421661
// Check if the specified list of dimensions to collapse is a valid list.
1643-
if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
1662+
if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
16441663
collapsableIterationDims)) {
16451664
return rewriter.notifyMatchFailure(
1646-
genericOp, "specified dimensions cannot be collapsed");
1665+
op, "specified dimensions cannot be collapsed");
16471666
}
16481667

16491668
std::optional<SmallVector<Value>> replacements =
1650-
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
1651-
rewriter);
1669+
collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
1670+
rewriter);
16521671
if (!replacements) {
1653-
return rewriter.notifyMatchFailure(genericOp,
1654-
"failed to collapse dimensions");
1672+
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
16551673
}
1656-
rewriter.replaceOp(genericOp, *replacements);
1674+
rewriter.replaceOp(op, *replacements);
16571675
return success();
16581676
}
16591677

@@ -1884,8 +1902,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
18841902
void mlir::linalg::populateCollapseDimensions(
18851903
RewritePatternSet &patterns,
18861904
const GetCollapsableDimensionsFn &controlCollapseDimensions) {
1887-
patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
1888-
controlCollapseDimensions);
1905+
patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
1906+
CollapseLinalgDimensions<linalg::CopyOp>>(
1907+
patterns.getContext(), controlCollapseDimensions);
18891908
}
18901909

18911910
//===---------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
116116
}
117117
return %alloc : memref<2x6x24x48xi32>
118118
}
119+
120+
// -----
121+
122+
// CHECK-LABEL: func.func @linalg_copy(
123+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
124+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
125+
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
126+
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
127+
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
128+
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
129+
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
130+
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
131+
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
132+
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
133+
// CHECK: }
134+
135+
func.func @linalg_copy(
136+
%arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
137+
%0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
138+
return %0 : tensor<1x2x3x4x5xf32, 3>
139+
}
140+
141+
// -----
142+
143+
// CHECK-LABEL: func.func private @memref_linalg_copy(
144+
// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
145+
// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
146+
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
147+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
148+
// CHECK: linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
149+
// CHECK: return
150+
// CHECK: }
151+
152+
func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
153+
linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
154+
return
155+
}

mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
258258
SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
259259
collapseDimensions.end());
260260
linalg::GetCollapsableDimensionsFn collapseFn =
261-
[&dims](linalg::GenericOp op) {
261+
[&dims](linalg::LinalgOp op) {
262262
SmallVector<ReassociationIndices> reassociations;
263263
reassociations.emplace_back(dims);
264264
return reassociations;

0 commit comments

Comments
 (0)