@@ -595,44 +595,45 @@ class ExpansionInfo {
595
595
// the expanded op.
596
596
LogicalResult compute (LinalgOp linalgOp, OpOperand *fusableOpOperand,
597
597
ArrayRef<AffineMap> reassociationMaps,
598
- ArrayRef<int64_t > expandedShape,
599
- ArrayRef<int64_t > collapsedShape,
598
+ ArrayRef<OpFoldResult> expandedShape,
600
599
PatternRewriter &rewriter);
601
600
unsigned getOrigOpNumDims () const { return reassociation.size (); }
602
601
unsigned getExpandedOpNumDims () const { return expandedOpNumDims; }
603
602
ReassociationIndicesRef getExpandedDims (unsigned i) const {
604
603
return reassociation[i];
605
604
}
606
- ArrayRef<int64_t > getExpandedShapeOfDim (unsigned i) const {
605
+ ArrayRef<OpFoldResult > getExpandedShapeOfDim (unsigned i) const {
607
606
return expandedShapeMap[i];
608
607
}
609
- ArrayRef<int64_t > getOriginalShape () const { return originalLoopExtent; }
608
+ ArrayRef<OpFoldResult > getOriginalShape () const { return originalLoopExtent; }
610
609
611
610
private:
612
611
// / Reassociation from the dimensions in the original operation to the
613
612
// / dimension of the expanded operation.
614
613
SmallVector<ReassociationIndices> reassociation;
615
614
// / Mapping from extent of loops in the original operation, to the extent of
616
615
// / loops in the expanded operation.
617
- SmallVector<SmallVector<int64_t >> expandedShapeMap;
616
+ SmallVector<SmallVector<OpFoldResult >> expandedShapeMap;
618
617
// / Extent of the loop in the original operation.
619
- SmallVector<int64_t > originalLoopExtent;
618
+ SmallVector<OpFoldResult > originalLoopExtent;
620
619
unsigned expandedOpNumDims;
621
620
};
622
621
} // namespace
623
622
624
623
LogicalResult ExpansionInfo::compute (LinalgOp linalgOp,
625
624
OpOperand *fusableOpOperand,
626
625
ArrayRef<AffineMap> reassociationMaps,
627
- ArrayRef<int64_t > expandedShape,
628
- ArrayRef<int64_t > collapsedShape,
626
+ ArrayRef<OpFoldResult> expandedShape,
629
627
PatternRewriter &rewriter) {
630
628
if (reassociationMaps.empty ())
631
629
return failure ();
632
630
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap (fusableOpOperand);
633
631
634
- SmallVector<int64_t , 4 > originalLoopRange = linalgOp.getStaticLoopRanges ();
635
- originalLoopExtent.assign (originalLoopRange.begin (), originalLoopRange.end ());
632
+ OpBuilder::InsertionGuard g (rewriter);
633
+ rewriter.setInsertionPoint (linalgOp);
634
+ originalLoopExtent = llvm::map_to_vector (
635
+ linalgOp.createLoopRanges (rewriter, linalgOp->getLoc ()),
636
+ [](Range r) { return r.size ; });
636
637
637
638
reassociation.clear ();
638
639
expandedShapeMap.clear ();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
644
645
unsigned pos = cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
645
646
AffineMap foldedDims = reassociationMaps[resultExpr.index ()];
646
647
numExpandedDims[pos] = foldedDims.getNumResults ();
647
- ArrayRef<int64_t > shape =
648
+ ArrayRef<OpFoldResult > shape =
648
649
expandedShape.slice (foldedDims.getDimPosition (0 ), numExpandedDims[pos]);
649
650
expandedShapeMap[pos].assign (shape.begin (), shape.end ());
650
651
}
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
665
666
return success ();
666
667
}
667
668
668
- // / Expanding the body of a linalg operation requires adaptations of the
669
- // / accessed loop indices. Specifically, access of indices in the original
670
- // / operation need to be replaced with linearizations of indices in the expanded
671
- // / op. That requires the shape of the expanded dimensions to be static (at
672
- // / least all but the most significant). For now check that these are all
673
- // / statically sized. Note that this could be extended to handle dynamic case,
674
- // / but the implementation below uses `affine.apply` which seems to have issues
675
- // / when the shapes are not static.
676
- static LogicalResult isLinalgOpExpandable (LinalgOp linalgOp,
677
- const ExpansionInfo &expansionInfo,
678
- PatternRewriter &rewriter) {
679
- if (!linalgOp.hasIndexSemantics ())
680
- return success ();
681
- for (unsigned i : llvm::seq<unsigned >(0 , expansionInfo.getOrigOpNumDims ())) {
682
- ArrayRef<int64_t > expandedShape = expansionInfo.getExpandedShapeOfDim (i);
683
- if (expandedShape.size () == 1 )
684
- continue ;
685
- for (int64_t shape : expandedShape.drop_front ()) {
686
- if (ShapedType::isDynamic (shape)) {
687
- return rewriter.notifyMatchFailure (
688
- linalgOp, " cannot expand due to index semantics and dynamic dims" );
689
- }
690
- }
691
- }
692
- return success ();
693
- }
694
-
695
669
// / Return the indexing map to use in the expanded op for a given the
696
670
// / `indexingMap` of the original operation.
697
671
static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
713
687
714
688
// / Return the type of the operand/result to use in the expanded op given the
715
689
// / type in the original op.
716
- static RankedTensorType getExpandedType (RankedTensorType originalType,
717
- AffineMap indexingMap,
718
- const ExpansionInfo &expansionInfo) {
719
- SmallVector<int64_t > expandedShape;
690
+ static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
691
+ getExpandedShapeAndType (RankedTensorType originalType, AffineMap indexingMap,
692
+ const ExpansionInfo &expansionInfo) {
693
+ SmallVector<int64_t > expandedStaticShape;
694
+ SmallVector<OpFoldResult> expandedShape;
720
695
for (AffineExpr expr : indexingMap.getResults ()) {
721
696
unsigned dim = cast<AffineDimExpr>(expr).getPosition ();
722
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim (dim);
697
+ ArrayRef<OpFoldResult> dimExpansion =
698
+ expansionInfo.getExpandedShapeOfDim (dim);
699
+ llvm::append_range (expandedStaticShape,
700
+ llvm::map_range (dimExpansion, [](OpFoldResult ofr) {
701
+ std::optional<int64_t > staticShape =
702
+ getConstantIntValue (ofr);
703
+ if (staticShape) {
704
+ return staticShape.value ();
705
+ }
706
+ return ShapedType::kDynamic ;
707
+ }));
723
708
expandedShape.append (dimExpansion.begin (), dimExpansion.end ());
724
709
}
725
- return RankedTensorType::get (expandedShape, originalType.getElementType ());
710
+ return {expandedShape, RankedTensorType::get (expandedStaticShape,
711
+ originalType.getElementType ())};
726
712
}
727
713
728
714
// / Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
770
756
// Linearize the expanded indices of the original index dimension.
771
757
OpBuilder::InsertionGuard guard (rewriter);
772
758
rewriter.setInsertionPointAfter (indexOp);
773
- ArrayRef<int64_t > expandedDimsShape =
759
+ ArrayRef<OpFoldResult > expandedDimsShape =
774
760
expansionInfo.getExpandedShapeOfDim (indexOp.getDim ()).drop_front ();
775
761
SmallVector<Value> expandedIndices;
776
762
expandedIndices.reserve (expandedDims.size () - 1 );
777
763
llvm::transform (
778
764
expandedDims.drop_front (), std::back_inserter (expandedIndices),
779
765
[&](int64_t dim) { return rewriter.create <IndexOp>(loc, dim); });
780
- Value newIndex = rewriter.create <IndexOp>(loc, expandedDims.front ());
766
+ OpFoldResult newIndex =
767
+ rewriter.create <IndexOp>(loc, expandedDims.front ()).getResult ();
781
768
for (auto it : llvm::zip (expandedDimsShape, expandedIndices)) {
782
- assert (!ShapedType::isDynamic (std::get<0 >(it)));
783
- AffineExpr idx, acc;
769
+ AffineExpr idx, acc, shape;
784
770
bindDims (rewriter.getContext (), idx, acc);
785
- newIndex = rewriter.create <affine::AffineApplyOp>(
786
- indexOp.getLoc (), idx + acc * std::get<0 >(it),
787
- ValueRange{std::get<1 >(it), newIndex});
788
- }
789
- rewriter.replaceOp (indexOp, newIndex);
790
- }
791
- }
792
-
793
- // / Checks if a single dynamic dimension expanded into multiple dynamic
794
- // / dimensions.
795
- static LogicalResult
796
- validateDynamicDimExpansion (LinalgOp linalgOp,
797
- const ExpansionInfo &expansionInfo,
798
- PatternRewriter &rewriter) {
799
- for (unsigned i : llvm::seq<unsigned >(0 , expansionInfo.getOrigOpNumDims ())) {
800
- ArrayRef<int64_t > expandedShape = expansionInfo.getExpandedShapeOfDim (i);
801
- if (expandedShape.size () == 1 )
802
- continue ;
803
- bool foundDynamic = false ;
804
- for (int64_t shape : expandedShape) {
805
- if (!ShapedType::isDynamic (shape))
806
- continue ;
807
- if (foundDynamic) {
808
- return rewriter.notifyMatchFailure (
809
- linalgOp, " cannot infer expanded shape with multiple dynamic "
810
- " dims in the same reassociation group" );
811
- }
812
- foundDynamic = true ;
771
+ bindSymbols (rewriter.getContext (), shape);
772
+ newIndex = affine::makeComposedFoldedAffineApply (
773
+ rewriter, indexOp.getLoc (), idx + acc * shape,
774
+ ArrayRef<OpFoldResult>{std::get<1 >(it), newIndex, std::get<0 >(it)});
813
775
}
776
+ Value newIndexVal =
777
+ getValueOrCreateConstantIndexOp (rewriter, indexOp.getLoc (), newIndex);
778
+ rewriter.replaceOp (indexOp, newIndexVal);
814
779
}
815
- return success ();
816
780
}
817
781
818
782
// / Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
826
790
" preconditions for fuse operation failed" );
827
791
828
792
Location loc = linalgOp.getLoc ();
829
- // Check if reshape is expanding or collapsing.
830
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
831
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
832
- bool isExpanding = (expandingReshapeOp != nullptr );
833
- RankedTensorType expandedType = isExpanding
834
- ? expandingReshapeOp.getResultType ()
835
- : collapsingReshapeOp.getSrcType ();
836
- RankedTensorType collapsedType = isExpanding
837
- ? expandingReshapeOp.getSrcType ()
838
- : collapsingReshapeOp.getResultType ();
793
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
794
+ SmallVector<AffineMap, 4 > reassociationIndices;
795
+ Value src;
796
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
797
+ expandedShape = expandingReshapeOp.getMixedOutputShape ();
798
+ reassociationIndices = expandingReshapeOp.getReassociationMaps ();
799
+ src = expandingReshapeOp.getSrc ();
800
+ } else {
801
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
802
+ expandedShape = tensor::getMixedSizes (
803
+ rewriter, collapsingReshapeOp->getLoc (), collapsingReshapeOp.getSrc ());
804
+ reassociationIndices = collapsingReshapeOp.getReassociationMaps ();
805
+ src = collapsingReshapeOp.getSrc ();
806
+ }
839
807
840
808
ExpansionInfo expansionInfo;
841
809
if (failed (expansionInfo.compute (
842
- linalgOp, fusableOpOperand,
843
- isExpanding ? expandingReshapeOp.getReassociationMaps ()
844
- : collapsingReshapeOp.getReassociationMaps (),
845
- expandedType.getShape (), collapsedType.getShape (), rewriter)))
846
- return std::nullopt;
847
-
848
- // TODO: With the support of multiple dynamic dims expansion in
849
- // tensor.expand_shape op, this case can be handled.
850
- if (failed (validateDynamicDimExpansion (linalgOp, expansionInfo, rewriter)))
851
- return std::nullopt;
852
-
853
- if (failed (isLinalgOpExpandable (linalgOp, expansionInfo, rewriter)))
810
+ linalgOp, fusableOpOperand, reassociationIndices,
811
+ expandedShape, rewriter)))
854
812
return std::nullopt;
855
813
856
814
SmallVector<AffineMap, 4 > expandedOpIndexingMaps = llvm::to_vector<4 >(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
866
824
expandedOpOperands.reserve (linalgOp.getNumDpsInputs ());
867
825
for (OpOperand *opOperand : linalgOp.getDpsInputOperands ()) {
868
826
if (opOperand == fusableOpOperand) {
869
- expandedOpOperands.push_back (isExpanding ? expandingReshapeOp.getSrc ()
870
- : collapsingReshapeOp.getSrc ());
827
+ expandedOpOperands.push_back (src);
871
828
continue ;
872
829
}
873
830
if (auto opOperandType =
874
831
dyn_cast<RankedTensorType>(opOperand->get ().getType ())) {
875
832
AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
876
- RankedTensorType expandedOperandType =
877
- getExpandedType (opOperandType, indexingMap, expansionInfo);
833
+ SmallVector<OpFoldResult> expandedOperandShape;
834
+ RankedTensorType expandedOperandType;
835
+ std::tie (expandedOperandShape, expandedOperandType) =
836
+ getExpandedShapeAndType (opOperandType, indexingMap, expansionInfo);
878
837
if (expandedOperandType != opOperand->get ().getType ()) {
879
838
// Reshape the operand to get the right type.
880
839
SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
888
847
/* isExpandingReshape=*/ true )))
889
848
return std::nullopt;
890
849
expandedOpOperands.push_back (rewriter.create <tensor::ExpandShapeOp>(
891
- loc, expandedOperandType, opOperand->get (), reassociation));
850
+ loc, expandedOperandType, opOperand->get (), reassociation,
851
+ expandedOperandShape));
892
852
continue ;
893
853
}
894
854
}
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
899
859
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable ()) {
900
860
AffineMap indexingMap = linalgOp.getMatchingIndexingMap (&opOperand);
901
861
auto opOperandType = cast<RankedTensorType>(opOperand.get ().getType ());
902
- RankedTensorType expandedOutputType =
903
- getExpandedType (opOperandType, indexingMap, expansionInfo);
862
+ SmallVector<OpFoldResult> expandedOutputShape;
863
+ RankedTensorType expandedOutputType;
864
+ std::tie (expandedOutputShape, expandedOutputType) =
865
+ getExpandedShapeAndType (opOperandType, indexingMap, expansionInfo);
904
866
if (expandedOutputType != opOperand.get ().getType ()) {
905
867
SmallVector<ReassociationIndices> reassociation =
906
868
getReassociationForExpansion (indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
913
875
/* isExpandingReshape=*/ true )))
914
876
return std::nullopt;
915
877
outputs.push_back (rewriter.create <tensor::ExpandShapeOp>(
916
- loc, expandedOutputType, opOperand.get (), reassociation));
878
+ loc, expandedOutputType, opOperand.get (), reassociation,
879
+ expandedOutputShape));
917
880
} else {
918
881
outputs.push_back (opOperand.get ());
919
882
}
0 commit comments