Commit 3f18f6a

[mlir][linalg] Enable fusion by expansion of reduction and named ops (llvm#83473)
This adds support for expansion of named linalg ops and linalg ops with reduction iterators. It improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function that restricts propagation of reshapes by expansion through linalg ops with reduction iterators. For named linalg ops, this fusion always converts the named op into a linalg.generic.
1 parent 37293e6 commit 3f18f6a
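
The control function mentioned above is the intended hook for users of these patterns who want the old behavior back. Below is a minimal sketch (not part of this commit; the helper name and exact includes are illustrative assumptions) that registers the expansion patterns with a ControlFusionFn permitting fusion only through linalg.generic ops whose loops are all parallel:

// Sketch only: approximates the pre-commit fusion behavior via the control
// function. The helper name and header choices are illustrative assumptions.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void addExpansionPatternsWithOldBehavior(RewritePatternSet &patterns) {
  linalg::ControlFusionFn controlFn = [](OpOperand *fusableOpOperand) {
    // The structured op taking part in the fusion is either the operand's
    // owner (consumer case) or the producer of the operand's value.
    Operation *candidate = fusableOpOperand->getOwner();
    if (!isa<linalg::LinalgOp>(candidate))
      candidate = fusableOpOperand->get().getDefiningOp();
    // Old behavior: only linalg.generic ops with all-parallel loops fuse.
    auto genericOp = dyn_cast_or_null<linalg::GenericOp>(candidate);
    return genericOp && llvm::all_of(genericOp.getIteratorTypesArray(),
                                     linalg::isParallelIterator);
  };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
}

With a control function like this installed, reshape propagation stops at named ops and at linalg ops carrying reduction iterators, matching the behavior prior to this change.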

File tree

2 files changed: +219 −50 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 58 additions & 50 deletions
@@ -467,10 +467,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 // expanding the dimensionality of the elementwise operations.
 //===---------------------------------------------------------------------===//
 
-/// Conditions for folding a generic operation with a reshape op by expanding
-/// the iteration space dimensionality for tensor operations. These are
-/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
-/// following fusion pattern.
+/// Conditions for folding a structured linalg operation with a reshape op by
+/// expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
+/// the following fusion pattern.
 ///
 /// Consider
 ///
@@ -481,9 +481,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-/// The reshape can be folded into the `genericOp` if its loop dimensionality
+/// The reshape can be folded into the `linalgOp` if its loop dimensionality
 /// is increased to match the result (operand) of the tensor.expand_shape.
-/// The indexing_map of the fused tensor in the `genericOp` and the
+/// The indexing_map of the fused tensor in the `linalgOp` and the
 /// reassociation map helps compute the indexing maps of the modified op.
 /// For the above example, based on the reassociation map it
 /// can be concluded that
@@ -502,7 +502,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///   d1 -> e2, e3, e4
 ///   d2 -> e5
 ///
-/// substituting this, the generic op can be rewritten as
+/// substituting this, the structured op can be rewritten as
 ///
 ///   %d = linalg.generic ins(%0, %1 : )
 ///        indexing_maps =
@@ -520,23 +520,28 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///
 /// The added reshapes are again expanding patterns, so they will get fused
 /// with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
+static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
-  return genericOp.hasPureTensorSemantics() &&
-         llvm::all_of(genericOp.getIndexingMaps().getValue(),
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
+  return linalgOp.hasPureTensorSemantics() &&
+         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
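
To make the relaxed precondition concrete: for a matmul-like op with iterators (d0, d1, d2) = (parallel, parallel, reduction) and indexing maps A -> (d0, d2), B -> (d2, d1), C -> (d0, d1), a reshape fused through C is now accepted because C indexes only parallel dims, while a reshape on A is still rejected because A indexes the reduction dim d2. The following self-contained sketch mirrors that per-operand check with plain C++ containers in place of MLIR's AffineMap and iterator types; all names are illustrative:

// Illustrative only: mimics the per-operand parallelism check above using
// plain containers instead of MLIR types.
#include <cassert>
#include <vector>

enum class Iterator { Parallel, Reduction };

// Loop dimensions referenced by an operand's projected-permutation map.
using OperandDims = std::vector<unsigned>;

// Relaxed condition: the operand must not be a scalar, and every loop it
// indexes must be parallel; other loops of the op may be reductions.
static bool isFusableOperand(const OperandDims &dims,
                             const std::vector<Iterator> &iterators) {
  if (dims.empty())
    return false;
  for (unsigned d : dims)
    if (iterators[d] != Iterator::Parallel)
      return false;
  return true;
}

int main() {
  // Matmul iteration space: (d0, d1, d2) = (parallel, parallel, reduction).
  std::vector<Iterator> iters = {Iterator::Parallel, Iterator::Parallel,
                                 Iterator::Reduction};
  assert(!isFusableOperand({0, 2}, iters)); // A(d0, d2): touches d2.
  assert(isFusableOperand({0, 1}, iters));  // C(d0, d1): parallel only.
  return 0;
}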
@@ -628,10 +633,10 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-static LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                           const ExpansionInfo &expansionInfo,
-                                           PatternRewriter &rewriter) {
-  if (!genericOp.hasIndexSemantics())
+static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
+                                          const ExpansionInfo &expansionInfo,
+                                          PatternRewriter &rewriter) {
+  if (!linalgOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
@@ -640,7 +645,7 @@ static LogicalResult isGenericOpExpandable(GenericOp genericOp,
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
         return rewriter.notifyMatchFailure(
-            genericOp, "cannot expand due to index semantics and dynamic dims");
+            linalgOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
@@ -749,10 +754,10 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
 /// that those conditions have been satisfied.
 static std::optional<SmallVector<Value>>
-fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
+fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
@@ -767,35 +772,35 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          genericOp, fusableOpOperand,
+          linalgOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;
 
-  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
+  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
+      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));
 
   // Set insertion point to the generic op.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp);
+  rewriter.setInsertionPoint(linalgOp);
 
   SmallVector<Value> expandedOpOperands;
-  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
-  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
+  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                                : collapsingReshapeOp.getSrc());
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
@@ -804,25 +809,25 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
             getReassociationForExpansion(indexingMap, expansionInfo);
         if (failed(reshapeLikeShapesAreCompatible(
                 [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(genericOp, msg);
+                  return rewriter.notifyMatchFailure(linalgOp, msg);
                 },
                 opOperandType.getShape(), expandedOperandType.getShape(),
                 reassociation,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
         continue;
       }
     }
     expandedOpOperands.push_back(opOperand->get());
   }
 
-  Location loc = genericOp.getLoc();
+  Location loc = linalgOp.getLoc();
   SmallVector<Value> outputs;
-  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
-    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
     RankedTensorType expandedOutputType =
         getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -831,14 +836,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
           getReassociationForExpansion(indexingMap, expansionInfo);
       if (failed(reshapeLikeShapesAreCompatible(
               [&](const Twine &msg) {
-                return rewriter.notifyMatchFailure(genericOp, msg);
+                return rewriter.notifyMatchFailure(linalgOp, msg);
               },
               opOperandType.getShape(), expandedOutputType.getShape(),
               reassociation,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand.get(),
+          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
           reassociation));
     } else {
       outputs.push_back(opOperand.get());
@@ -848,14 +853,17 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
-      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                  /*inputs=*/expandedOpOperands, outputs,
                                  expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = genericOp->getRegion(0);
+  Region &originalRegion = linalgOp->getRegion(0);
   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
 
   // Update the index accesses after the expansion.
@@ -864,16 +872,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
-  for (OpResult opResult : genericOp->getOpResults()) {
+  for (OpResult opResult : linalgOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getMatchingIndexingMap(
-                  genericOp.getDpsInitOperand(resultNumber)),
+              linalgOp.getMatchingIndexingMap(
+                  linalgOp.getDpsInitOperand(resultNumber)),
               expansionInfo);
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
-          genericOp.getLoc(), opResult.getType(),
+          linalgOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -885,37 +893,37 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
 
 namespace {
 
-/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
+/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
 /// in the consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
       tensor::CollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;
 
      std::optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
-      rewriter.replaceOp(genericOp, *replacementValues);
+      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
@@ -945,7 +953,7 @@ struct FoldReshapeWithGenericOpByExpansion
                                          "source not produced by an operation");
     }
 
-    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
     if (!producer) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "producer not a generic op");
