 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include <optional>
 #include <utility>
@@ -590,44 +591,45 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
 
 private:
   /// Reassociation from the dimensions in the original operation to the
   /// dimension of the expanded operation.
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
 
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });
 
   reassociation.clear();
   expandedShapeMap.clear();
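
Background for the type change above: `OpFoldResult` is a `PointerUnion<Attribute, Value>`, so each loop extent can now carry either a compile-time constant or a dynamic SSA value, and `Range::size` as produced by `createLoopRanges` is already in that form. A minimal sketch of the recovery idiom the rest of the patch leans on (the helper name and its caller are hypothetical, not from this file):

```cpp
// Sketch only: fold an OpFoldResult extent back to the static encoding,
// using the ShapedType dynamic sentinel when no constant can be extracted.
static int64_t toStaticExtent(OpFoldResult extent) {
  // getConstantIntValue succeeds for an integer Attribute or a Value defined
  // by a constant op; anything else is treated as dynamic.
  if (std::optional<int64_t> cst = getConstantIntValue(extent))
    return *cst;
  return ShapedType::kDynamic;
}
```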
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, accesses of indices in the original
-/// operation need to be replaced with linearizations of indices in the
-/// expanded op. That requires the shape of the expanded dimensions to be
-/// static (at least all but the most significant). For now check that these
-/// are all statically sized. Note that this could be extended to handle the
-/// dynamic case, but the implementation below uses `affine.apply`, which
-/// seems to have issues when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -708,16 +683,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 
 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }
 
 /// Returns the reassociation maps to use in the `tensor.expand_shape`
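
The lambda in this hunk re-derives the static sizes element by element. For comparison, MLIR's StaticValueUtils helpers can perform the same static/dynamic split in one call; a hedged sketch of that alternative (assuming `dispatchIndexOpFoldResults` keeps its usual signature; this is not what the patch itself does):

```cpp
// Sketch: split mixed sizes into static ints (kDynamic in the dynamic slots)
// plus the dynamic Values, then build the tensor type from the statics.
SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(expandedShape, dynamicSizes, staticSizes);
auto expandedType =
    RankedTensorType::get(staticSizes, originalType.getElementType());
```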
@@ -765,49 +752,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }
 
 // Create an expanded transpose op.
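
What the rewritten loop computes: with expanded indices d0..dn (d0 most significant) and extents s0..sn, the chain of folded affine applies rebuilds the original index as (..((d0 * s1 + d1) * s2 + d2)..) * sn + dn. The leading extent s0 never enters the formula, which is why the shape is consumed with `drop_front()`, and each extent is now passed as the affine symbol `shape` instead of being folded in as a constant, so dynamic extents are acceptable. A scalar analogue for intuition (a hypothetical standalone helper, not code from the patch):

```cpp
// Horner-style linearization mirroring the `idx + acc * shape` chain above.
// extents.front() is never read, so the most-significant extent may be
// unknown (dynamic) without affecting the result.
static int64_t linearizeIndex(ArrayRef<int64_t> dims,
                              ArrayRef<int64_t> extents) {
  assert(!dims.empty() && dims.size() == extents.size());
  int64_t acc = dims.front();
  for (auto [idx, shape] : llvm::zip(dims.drop_front(), extents.drop_front()))
    acc = idx + acc * shape; // one folded affine.apply per step
  return acc;
}
```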
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
 
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the
+    // `linalgOp` to maintain SSA validity.
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }
 
   ExpansionInfo expansionInfo;
-  if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+                                   reassociationIndices, expandedShape,
+                                   rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
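
Why the `moveValueDefinitions` call (and the new RegionUtils.h include) is needed: the fused op and the `tensor.expand_shape` ops built for its remaining operands are created at the original `linalgOp`'s insertion point, so every SSA value feeding the reshape's output shape must dominate that point. The helper hoists those definitions above `linalgOp` when that is legal; when it is not, the pattern bails out with `std::nullopt` rather than emit dominance-violating IR.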
@@ -950,15 +915,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +938,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -983,8 +950,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +966,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
          return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }