#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+ #include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

@@ -572,6 +573,39 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
  return projectedDimsPos;
}

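+ /// Applies `dimsPerm` to `reassociationIndices` and renumbers the entries so
+ /// they count up contiguously from zero in the permuted order, returning the
+ /// next unused position. For example, [[0, 1], [2]] permuted by [1, 0]
+ /// becomes [[2], [0, 1]] and is renumbered to [[0], [1, 2]], returning 3.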
+ static int64_t applyPermutationAndReindexReassoc(
+     SmallVector<ReassociationIndices> &reassociationIndices,
+     ArrayRef<int64_t> dimsPerm) {
+   applyPermutationToVector<ReassociationIndices>(reassociationIndices,
+                                                  dimsPerm);
+   int64_t lastPos = 0;
+   for (ReassociationIndices &indices : reassociationIndices) {
+     for (auto &index : indices) {
+       index = lastPos;
+       lastPos += 1;
+     }
+   }
+   return lastPos;
+ }
+
+ /// Bubble up pack op through collapse shape op when the packed dims can be
+ /// mapped to the source dims before collapsing. This is possible when the
+ /// inner tile sizes can divide the mapped source dims.
+ ///
+ /// For example:
+ ///
+ /// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
+ ///     : tensor<?x16x4xf32> into tensor<?x4xf32>
+ /// %out = tensor.empty() : tensor<?x4x8x1xf32>
+ /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
+ ///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %out
+ ///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+ ///
+ /// Can be transformed into:
+ ///
+ /// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
+ /// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
+ ///     inner_tiles = [8, 1] into %out
+ ///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+ /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
+ ///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
@@ -580,27 +614,23 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

- if (llvm::any_of(innerTileSizes,
-                  [](int64_t size) { return ShapedType::isDynamic(size); })) {
-   return failure();
- }
-
  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos =
+ SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

- // Check if the base dims before reassociation are divisible by the inner tile
+ // Check if the projected dims on the source are divisible by the inner tile
  // sizes.
- for (auto [basePos, tileSize] :
-      llvm::zip_equal(baseDimsPos, innerTileSizes)) {
-   int64_t dim = srcShape[basePos];
-   if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+ for (auto [projectedPos, tileSize] :
+      llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+   int64_t dim = srcShape[projectedPos];
+   if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return failure();
-   }
  }
- // Expand the outer dims perm with associated src dims.
+ // Expand the outer dims permutation with the associated source dims for the
+ // new permutation after bubbling. This is because moving a collapsed dim is
+ // equivalent to moving the associated source dims together.
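+ // For example, with reassociation [[0, 1], [2]] and outer_dims_perm = [1, 0],
+ // the expanded permutation becomes [2, 0, 1].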
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -610,23 +640,19 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
-     baseDimsPos, newOuterDimsPerm);
+     projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
-     packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+     packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

- SmallVector<ReassociationIndices> newReassocIndices;
- int64_t currPos = 0;
- for (auto outerPos : outerDimsPerm) {
-   int64_t start = currPos;
-   int64_t end = start + reassocIndices[outerPos].size();
-   newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
-   currPos = end;
- }
- for (auto unused : innerTileSizes) {
-   (void)unused;
-   newReassocIndices.push_back({currPos});
-   currPos += 1;
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First build reassociations on the outer dims after the permutation.
+ int64_t lastPos =
+     applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
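+ // In the documented example above, this yields [[0, 1], [2]] with lastPos ==
+ // 3; the two inner tile dims then append [3] and [4], giving [[0, 1], 2, 3, 4].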
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+   newReassocIndices.push_back({lastPos});
+   lastPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -644,18 +670,28 @@ class BubbleUpPackOpThroughReshapeOp final

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
-   if (packOp.getPaddingValue())
+   // User-controlled propagation function.
+   if (!controlFn(packOp))
      return failure();

    Operation *srcOp = packOp.getSource().getDefiningOp();
+   // Currently only support when the pack op is the only user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
-       !srcOp->getResult(0).hasOneUse())
+       !srcOp->getResult(0).hasOneUse()) {
      return failure();
-
-   if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
-     return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
    }
-   return failure();
+   // Currently only support static inner tile sizes.
+   if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
+         return ShapedType::isDynamic(size);
+       })) {
+     return failure();
+   }
+
+   return TypeSwitch<Operation *, LogicalResult>(srcOp)
+       .Case([&](tensor::CollapseShapeOp op) {
+         return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
+       })
+       .Default([](Operation *) { return failure(); });
  }

private:
@@ -666,65 +702,59 @@ static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                   tensor::ExpandShapeOp expandOp,
                                   PatternRewriter &rewriter) {
-
  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

- if (llvm::any_of(innerTileSizes,
-                  [](int64_t size) { return ShapedType::isDynamic(size); })) {
-   return failure();
- }
-
  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
- SmallVector<int64_t> baseDimsPos =
+ SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

- // Check if the base dims after reassociation are divisible by the inner tile
+ // Check if the projected dims on the dest are divisible by the inner tile
  // sizes.
- for (auto [basePos, tileSize] :
-      llvm::zip_equal(baseDimsPos, innerTileSizes)) {
-   int64_t dim = dstShape[basePos];
-   if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+ for (auto [projectedPos, tileSize] :
+      llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
+   int64_t dim = dstShape[projectedPos];
+   if (ShapedType::isDynamic(dim) ||
+       (dstShape[projectedPos] % tileSize) != 0) {
      return failure();
    }
  }
- // Expand the outer dims perm with associated src dims.
+ // Expand the outer dims permutation with the associated expanded dims for the
+ // new permutation after pushing. This is because moving a source dim is
+ // equivalent to moving the associated expanded dims together.
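+ // For example, with reassociation [[0], [1, 2]] and outer_dims_perm = [1, 0],
+ // the expanded permutation becomes [1, 2, 0].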
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

- SmallVector<ReassociationIndices> newReassocIndices;
- int64_t currPos = 0;
- for (auto outerPos : outerDimsPerm) {
-   int64_t start = currPos;
-   int64_t end = start + reassocIndices[outerPos].size();
-   newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
-   currPos = end;
- }
- for (auto unused : innerTileSizes) {
-   (void)unused;
-   newReassocIndices.push_back({currPos});
-   currPos += 1;
+ SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
+ // First build reassociations on the outer dims after the permutation.
+ int64_t lastPos =
+     applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
+ // Then add direct mapping for the inner tile dims.
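+ // Continuing the example above, [[0], [1, 2]] permuted by [1, 0] is
+ // renumbered to [[0, 1], [2]] with lastPos == 3, and each inner tile dim
+ // then appends a singleton group.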
+ for (size_t i = 0; i < innerDimsPos.size(); ++i) {
+   newReassocIndices.push_back({lastPos});
+   lastPos += 1;
  }

- RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
-     expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+ RankedTensorType newExpandType =
+     tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
+                                     projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
-     baseDimsPos, newOuterDimsPerm);
+     projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
-     unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
-     unPackOp.getMixedTiles(), newOuterDimsPerm);
+     unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
+     projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
@@ -740,16 +770,28 @@ class PushDownUnPackOpThroughReshapeOp final

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
+   // User-controlled propagation function.
+   if (!controlFn(unPackOp))
+     return failure();
+
    Value result = unPackOp.getResult();
+   // Currently only support unpack op with a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
-   Operation *userOp = *result.user_begin();
-
-   if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
-     return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+   // Currently only support static inner tile sizes.
+   if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
+         return ShapedType::isDynamic(size);
+       })) {
+     return failure();
    }
-   return failure();
+
+   Operation *userOp = *result.user_begin();
+   return TypeSwitch<Operation *, LogicalResult>(userOp)
+       .Case([&](tensor::ExpandShapeOp op) {
+         return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
+       })
+       .Default([](Operation *) { return failure(); });
  }

private: