 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <utility>
@@ -590,44 +591,45 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace

 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });

   reassociation.clear();
   expandedShapeMap.clear();
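Note on the representation change in this hunk: each extent produced by `Range::size` is an `OpFoldResult`, i.e. either an `Attribute` carrying a static extent or an SSA `Value` for a dynamic one. A minimal sketch (not part of the patch) of how a consumer can test for the static case, assuming only the standard helper `getConstantIntValue` from `mlir/Dialect/Utils/StaticValueUtils.h`:

    #include <optional>
    #include "mlir/Dialect/Utils/StaticValueUtils.h"

    // Returns the extent when it is statically known (covers both the
    // Attribute form and a Value defined by a constant op); std::nullopt
    // signals a truly dynamic extent.
    static std::optional<int64_t> getStaticLoopExtent(mlir::OpFoldResult extent) {
      return mlir::getConstantIntValue(extent);
    }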
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }

-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
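The deleted guard above is exactly what the rest of this patch makes unnecessary. Written out, the linearization of expanded indices (i0, ..., ik) with per-group extents (s0, ..., sk) is i = (((i0 * s1 + i1) * s2 + i2) * ...) * sk + ik; the old implementation folded each sj into the affine expression as an `int64_t`, so every non-leading extent had to be static, while the rewritten version below binds each sj as an affine symbol, which is equally valid when sj is a dynamic SSA value.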
 /// Return the indexing map to use in the expanded op given the `indexingMap`
 /// of the original operation.
 static AffineMap
@@ -706,18 +681,23 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                        builder.getContext());
 }

-/// Return the type of the operand/result to use in the expanded op given the
-/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+/// Return the shape and type of the operand/result to use in the expanded op
+/// given the type in the original op.
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  SmallVector<int64_t> expandedStaticShape;
+  std::tie(expandedStaticShape, std::ignore) =
+      decomposeMixedValues(expandedShape);
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }

 /// Returns the reassociation maps to use in the `tensor.expand_shape`
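For illustration, `decomposeMixedValues` (from `mlir/Dialect/Utils/StaticValueUtils.h`) splits a mixed static/dynamic shape into the purely static form that `RankedTensorType` requires, with dynamic entries mapped to `ShapedType::kDynamic`, plus the list of dynamic `Value`s. A hedged sketch, where `dynamicSize` is a hypothetical `index`-typed `Value`:

    // Sketch only: the mixed shape {4, dynamicSize, 8} decomposes into a
    // static shape with a kDynamic placeholder and the dynamic sizes.
    SmallVector<OpFoldResult> mixed = {rewriter.getIndexAttr(4), dynamicSize,
                                       rewriter.getIndexAttr(8)};
    auto [staticShape, dynamicSizes] = decomposeMixedValues(mixed);
    // staticShape  == {4, ShapedType::kDynamic, 8}  -> tensor<4x?x8xElemTy>
    // dynamicSizes == {dynamicSize}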
@@ -765,49 +745,28 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
-    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
+    for (auto [expandedShape, expandedIndex] :
+         llvm::zip(expandedDimsShape, expandedIndices)) {
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }

 // Create an expanded transpose op.
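As a worked example of the linearization above (a sketch of the intended rewrite, not verbatim pass output): if loop 0 is expanded into dimensions (0, 1) with extents (%s0, %s1), a `linalg.index 0` in the body becomes roughly:

    // Pseudo-IR, in comment form:
    //   %i0  = linalg.index 0 : index
    //   %i1  = linalg.index 1 : index
    //   %lin = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
    //            (%i1, %i0)[%s1]
    // makeComposedFoldedAffineApply folds the symbol away when %s1 is a
    // constant, so the fully static case lowers as before.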
@@ -910,31 +869,34 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");

   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the `linalgOp`
+    // to maintain SSA validity.
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    if (!collapsingReshapeOp)
+      return std::nullopt;
+
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }
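The `moveValueDefinitions` helper used here comes from `mlir/Transforms/RegionUtils.h`, which is why that include is added at the top of this diff. It is needed because the expanded op is created at `linalgOp`'s position, while the `output_shape` operands of a consumer `tensor.expand_shape` may be defined between `linalgOp` and the reshape; the helper hoists their defining ops above `linalgOp` when dominance allows, and fusion bails out otherwise. A hypothetical before-IR, in comment form:

    // Before fusion, %sz is defined after the generic that produces %0:
    //   %0  = linalg.generic ... -> tensor<?xf32>
    //   %sz = arith.divui %total, %a : index
    //   %1  = tensor.expand_shape %0 [[0, 1]] output_shape [%a, %sz]
    //           : tensor<?xf32> into tensor<?x?xf32>
    // moveValueDefinitions moves the arith.divui (with its own operands, if
    // necessary) above the linalg.generic so the expanded op can use %sz.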

   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+                                   reassociationIndices, expandedShape,
+                                   rewriter)))
     return std::nullopt;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +912,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +935,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
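Both `tensor.expand_shape` creations in this function now pass the explicit mixed output shape, since the result type alone cannot carry dynamic extents. A minimal usage sketch, where `source` and `dynSize` are hypothetical values:

    // Sketch: expand tensor<?xf32> into tensor<4x?xf32>, supplying the
    // dynamic output extent explicitly as an OpFoldResult.
    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
    Value expanded = rewriter.create<tensor::ExpandShapeOp>(
        loc,
        RankedTensorType::get({4, ShapedType::kDynamic}, rewriter.getF32Type()),
        source, reassoc,
        ArrayRef<OpFoldResult>{rewriter.getIndexAttr(4), dynSize});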
@@ -983,8 +947,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +963,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }