@@ -467,10 +467,10 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 // expanding the dimensionality of the elementwise operations.
 //===---------------------------------------------------------------------===//

-/// Conditions for folding a generic operation with a reshape op by expanding
-/// the iteration space dimensionality for tensor operations. These are
-/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
-/// following fusion pattern.
+/// Conditions for folding a structured linalg operation with a reshape op by
+/// expanding the iteration space dimensionality for tensor operations. These
+/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
+/// the following fusion pattern.
 ///
 /// Consider
 ///
@@ -481,9 +481,9 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
 ///         : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 ///
-/// The reshape can be folded into the `genericOp` if its loop dimensionality
+/// The reshape can be folded into the `linalgOp` if its loop dimensionality
 /// is increased to match the result (operand) of the tensor.expand_shape.
-/// The indexing_map of the fused tensor in the `genericOp` and the
+/// The indexing_map of the fused tensor in the `linalgOp` and the
 /// reassociation map helps compute the indexing maps of the modified op.
 /// For the above example, based on the reassociation map it
 /// can be concluded that
@@ -502,7 +502,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///   d1 -> e2, e3, e4
 ///   d2 -> e5
 ///
-/// substituting this, the generic op can be rewritten as
+/// substituting this, the structured op can be rewritten as
 ///
 ///   %d = linalg.generic ins(%0, %1 : )
 ///        indexing_maps =
@@ -520,23 +520,28 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
 ///
 /// The added reshapes are again expanding patterns, so they will get fused
 /// with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
+static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
-  return genericOp.hasPureTensorSemantics() &&
-         llvm::all_of(genericOp.getIndexingMaps().getValue(),
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
+  return linalgOp.hasPureTensorSemantics() &&
+         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
                         return cast<AffineMapAttr>(attr)
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }

 namespace {
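The final conjunct above is the substantive relaxation: previously every loop of the op had to be parallel, whereas now only the loops appearing in the reshaped operand's indexing map must be. An illustrative sketch (not part of this patch) of a reduction op that the relaxed check newly admits:

    %c = linalg.matmul ins(%a, %b : tensor<8x16xf32>, tensor<16x32xf32>)
                       outs(%init : tensor<8x32xf32>) -> tensor<8x32xf32>
    %d = tensor.expand_shape %c [[0, 1], [2]]
        : tensor<8x32xf32> into tensor<2x4x32xf32>

The matmul's output map (d0, d1) touches only parallel loops, so the reshape can be folded through the result; fusing through the LHS, whose map (d0, d2) involves the reduction loop d2, remains rejected.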
@@ -628,10 +633,10 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
-static LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                           const ExpansionInfo &expansionInfo,
-                                           PatternRewriter &rewriter) {
-  if (!genericOp.hasIndexSemantics())
+static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
+                                          const ExpansionInfo &expansionInfo,
+                                          PatternRewriter &rewriter) {
+  if (!linalgOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
@@ -640,7 +645,7 @@ static LogicalResult isGenericOpExpandable(GenericOp genericOp,
     for (int64_t shape : expandedShape.drop_front()) {
       if (ShapedType::isDynamic(shape)) {
         return rewriter.notifyMatchFailure(
-            genericOp, "cannot expand due to index semantics and dynamic dims");
+            linalgOp, "cannot expand due to index semantics and dynamic dims");
       }
     }
   }
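For context on the dynamic-shape restriction: when the op has index semantics, the region update recomputes each original linalg.index value by linearizing the indices of the expanded loops with affine.apply, which needs the inner expanded sizes to be static. Roughly, as a sketch of the generated IR's shape (not its exact output), if d0 expands into e0 x e1 with inner size 4:

    %e0 = linalg.index 0 : index
    %e1 = linalg.index 1 : index
    %i0 = affine.apply affine_map<(d0, d1) -> (d0 * 4 + d1)>(%e0, %e1)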
@@ -749,10 +754,10 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
 /// that those conditions have been satisfied.
 static std::optional<SmallVector<Value>>
-fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
+fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                            OpOperand *fusableOpOperand,
                            PatternRewriter &rewriter) {
-  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
+  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
   // Check if reshape is expanding or collapsing.
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
@@ -767,35 +772,35 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,

   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          genericOp, fusableOpOperand,
+          linalgOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;

-  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
+  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
     return std::nullopt;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap m) {
+      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));

   // Set insertion point to the generic op.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(genericOp);
+  rewriter.setInsertionPoint(linalgOp);

   SmallVector<Value> expandedOpOperands;
-  expandedOpOperands.reserve(genericOp.getNumDpsInputs());
-  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
+  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                                : collapsingReshapeOp.getSrc());
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
@@ -804,25 +809,25 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
             getReassociationForExpansion(indexingMap, expansionInfo);
         if (failed(reshapeLikeShapesAreCompatible(
                 [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(genericOp, msg);
+                  return rewriter.notifyMatchFailure(linalgOp, msg);
                 },
                 opOperandType.getShape(), expandedOperandType.getShape(),
                 reassociation,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
         continue;
       }
     }
     expandedOpOperands.push_back(opOperand->get());
   }

-  Location loc = genericOp.getLoc();
+  Location loc = linalgOp.getLoc();
   SmallVector<Value> outputs;
-  for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) {
-    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
     RankedTensorType expandedOutputType =
         getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -831,14 +836,14 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
           getReassociationForExpansion(indexingMap, expansionInfo);
       if (failed(reshapeLikeShapesAreCompatible(
               [&](const Twine &msg) {
-                return rewriter.notifyMatchFailure(genericOp, msg);
+                return rewriter.notifyMatchFailure(linalgOp, msg);
               },
               opOperandType.getShape(), expandedOutputType.getShape(),
               reassociation,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand.get(),
+          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
           reassociation));
     } else {
       outputs.push_back(opOperand.get());
@@ -848,14 +853,17 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;

   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
-      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                  /*inputs=*/expandedOpOperands, outputs,
                                  expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = genericOp->getRegion(0);
+  Region &originalRegion = linalgOp->getRegion(0);
   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

   // Update the index accesses after the expansion.
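The three added lines are what let non-parallel iterators survive the expansion: the vector starts out all-parallel (as the pre-existing comment above describes), and each expanded dimension then inherits the iterator type of the original dimension it was derived from. A self-contained C++ sketch of the same logic, with stand-in types for the MLIR ones:

    #include <cstddef>
    #include <vector>

    enum class IteratorType { parallel, reduction };

    // expandedDims[i] lists the expanded loop indices derived from original
    // dimension i (the role ExpansionInfo::getExpandedDims plays above).
    std::vector<IteratorType>
    expandIteratorTypes(const std::vector<IteratorType> &origTypes,
                        const std::vector<std::vector<std::size_t>> &expandedDims,
                        std::size_t numExpandedDims) {
      // Default everything to parallel, then copy each original dim's type
      // to all of the dims it expands into.
      std::vector<IteratorType> result(numExpandedDims, IteratorType::parallel);
      for (std::size_t i = 0; i < origTypes.size(); ++i)
        for (std::size_t j : expandedDims[i])
          result[j] = origTypes[i];
      return result;
    }

With origTypes = {parallel, reduction} and expandedDims = {{0, 1}, {2}}, the result is {parallel, parallel, reduction}.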
@@ -864,16 +872,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
-  for (OpResult opResult : genericOp->getOpResults()) {
+  for (OpResult opResult : linalgOp->getOpResults()) {
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(
-              genericOp.getMatchingIndexingMap(
-                  genericOp.getDpsInitOperand(resultNumber)),
+              linalgOp.getMatchingIndexingMap(
+                  linalgOp.getDpsInitOperand(resultNumber)),
               expansionInfo);
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
-          genericOp.getLoc(), opResult.getType(),
+          linalgOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
     } else {
       resultVals.push_back(fusedOp->getResult(resultNumber));
@@ -885,37 +893,37 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,

 namespace {

-/// Pattern to fuse a tensor.collapse_shape op with its consumer generic op,
+/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
 /// when the reshape op is collapsing dimensions. The dimensionality of the loop
 /// in the consumer is expanded.
 class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
+    : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                        ControlFusionFn foldReshapes,
                                        PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}

-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
       tensor::CollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
         continue;
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
           (!controlFoldingReshapes(opOperand)))
         continue;

       std::optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
       if (!replacementValues)
         return failure();
-      rewriter.replaceOp(genericOp, *replacementValues);
+      rewriter.replaceOp(linalgOp, *replacementValues);
       return success();
     }
     return failure();
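Because the pattern now matches the LinalgOp interface instead of the concrete GenericOp, it also fires on named structured ops (e.g. linalg.matmul) without further changes. Downstream use is unchanged; a minimal usage sketch, assuming the existing population hook that registers this pattern:

    RewritePatternSet patterns(context);
    // Accept every candidate reshape; a real control function can be selective.
    linalg::populateFoldReshapeOpsByExpansionPatterns(
        patterns, [](OpOperand *fusedOperand) { return true; });
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));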
@@ -945,7 +953,7 @@ struct FoldReshapeWithGenericOpByExpansion
                                          "source not produced by an operation");
     }

-    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
     if (!producer) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "producer not a generic op");