@@ -552,6 +552,192 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
552
552
ControlPropagationFn controlFn;
553
553
};
554
554
555
+ static LogicalResult
556
+ bubbleUpPackOpThroughCollapseShape (tensor::CollapseShapeOp collapseOp,
557
+ tensor::PackOp packOp,
558
+ PatternRewriter &rewriter) {
559
+ SmallVector<int64_t > innerTileSizes = packOp.getStaticTiles ();
560
+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
561
+ ArrayRef<int64_t > outerDimsPerm = packOp.getOuterDimsPerm ();
562
+
563
+ if (llvm::any_of (innerTileSizes,
564
+ [](int64_t size) { return ShapedType::isDynamic (size); })) {
565
+ return failure ();
566
+ }
567
+
568
+ ArrayRef<int64_t > srcShape = collapseOp.getSrcType ().getShape ();
569
+ SmallVector<ReassociationIndices> reassocIndices =
570
+ collapseOp.getReassociationIndices ();
571
+ SmallVector<int64_t > baseDimsPos;
572
+ for (auto pos : innerDimsPos) {
573
+ baseDimsPos.push_back (reassocIndices[pos].back ());
574
+ }
575
+ // Check if the base dims before reassociation are divisible by the inner tile
576
+ // sizes.
577
+ for (auto [basePos, tileSize] :
578
+ llvm::zip_equal (baseDimsPos, innerTileSizes)) {
579
+ int64_t dim = srcShape[basePos];
580
+ if (ShapedType::isDynamic (dim) || (dim % tileSize) != 0 ) {
581
+ return failure ();
582
+ }
583
+ }
584
+ // Expand the outer dims perm with associated src dims.
585
+ SmallVector<int64_t > newOuterDimsPerm;
586
+ for (auto outerPos : outerDimsPerm) {
587
+ newOuterDimsPerm.insert (newOuterDimsPerm.end (),
588
+ reassocIndices[outerPos].begin (),
589
+ reassocIndices[outerPos].end ());
590
+ }
591
+
592
+ auto emptyOp = tensor::PackOp::createDestinationTensor (
593
+ rewriter, packOp.getLoc (), collapseOp.getSrc (), packOp.getMixedTiles (), baseDimsPos,
594
+ newOuterDimsPerm);
595
+ auto newPackOp = rewriter.create <tensor::PackOp>(
596
+ packOp.getLoc (), collapseOp.getSrc (), emptyOp, baseDimsPos, packOp.getMixedTiles (),
597
+ packOp.getPaddingValue (), newOuterDimsPerm);
598
+
599
+ SmallVector<ReassociationIndices> newReassocIndices;
600
+ int64_t currPos = 0 ;
601
+ for (auto outerPos : outerDimsPerm) {
602
+ int64_t start = currPos;
603
+ int64_t end = start + reassocIndices[outerPos].size ();
604
+ newReassocIndices.push_back (llvm::to_vector (llvm::seq (start, end)));
605
+ currPos = end;
606
+ }
607
+ for (auto unused : innerTileSizes) {
608
+ (void )unused;
609
+ newReassocIndices.push_back ({currPos});
610
+ currPos += 1 ;
611
+ }
612
+
613
+ auto newCollapseOp = rewriter.create <tensor::CollapseShapeOp>(
614
+ collapseOp.getLoc (), packOp.getType (), newPackOp, newReassocIndices);
615
+ rewriter.replaceOp (packOp, newCollapseOp);
616
+
617
+ return success ();
618
+ }
619
+
620
+ class BubbleUpPackOpThroughReshapeOp final
621
+ : public OpRewritePattern<tensor::PackOp> {
622
+ public:
623
+ BubbleUpPackOpThroughReshapeOp (MLIRContext *context, ControlPropagationFn fun)
624
+ : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
625
+
626
+ LogicalResult matchAndRewrite (tensor::PackOp packOp,
627
+ PatternRewriter &rewriter) const override {
628
+ if (packOp.getPaddingValue ())
629
+ return failure ();
630
+
631
+ Operation *srcOp = packOp.getSource ().getDefiningOp ();
632
+ if (!srcOp || !(srcOp->getNumResults () == 1 ) ||
633
+ !srcOp->getResult (0 ).hasOneUse ())
634
+ return failure ();
635
+
636
+ if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
637
+ return bubbleUpPackOpThroughCollapseShape (collapseOp, packOp, rewriter);
638
+ }
639
+ return failure ();
640
+ }
641
+
642
+ private:
643
+ ControlPropagationFn controlFn;
644
+ };
645
+
646
+ static LogicalResult
647
+ pushDownUnPackOpThroughExpandShape (tensor::UnPackOp unPackOp,
648
+ tensor::ExpandShapeOp expandOp,
649
+ PatternRewriter &rewriter) {
650
+
651
+ SmallVector<int64_t > innerTileSizes = unPackOp.getStaticTiles ();
652
+ ArrayRef<int64_t > innerDimsPos = unPackOp.getInnerDimsPos ();
653
+ ArrayRef<int64_t > outerDimsPerm = unPackOp.getOuterDimsPerm ();
654
+
655
+ if (llvm::any_of (innerTileSizes,
656
+ [](int64_t size) { return ShapedType::isDynamic (size); })) {
657
+ return failure ();
658
+ }
659
+
660
+ ArrayRef<int64_t > dstShape = expandOp.getType ().getShape ();
661
+ SmallVector<ReassociationIndices> reassocIndices =
662
+ expandOp.getReassociationIndices ();
663
+ SmallVector<int64_t > baseDimsPos;
664
+ for (auto pos : innerDimsPos) {
665
+ baseDimsPos.push_back (reassocIndices[pos].back ());
666
+ }
667
+ // Check if the base dims after reassociation are divisible by the inner tile
668
+ // sizes.
669
+ for (auto [basePos, tileSize] :
670
+ llvm::zip_equal (baseDimsPos, innerTileSizes)) {
671
+ int64_t dim = dstShape[basePos];
672
+ if (ShapedType::isDynamic (dim) || dstShape[basePos] % tileSize != 0 ) {
673
+ return failure ();
674
+ }
675
+ }
676
+ // Expand the outer dims perm with associated src dims.
677
+ SmallVector<int64_t > newOuterDimsPerm;
678
+ for (auto outerPos : outerDimsPerm) {
679
+ newOuterDimsPerm.insert (newOuterDimsPerm.end (),
680
+ reassocIndices[outerPos].begin (),
681
+ reassocIndices[outerPos].end ());
682
+ }
683
+
684
+ SmallVector<ReassociationIndices> newReassocIndices;
685
+ int64_t currPos = 0 ;
686
+ for (auto outerPos : outerDimsPerm) {
687
+ int64_t start = currPos;
688
+ int64_t end = start + reassocIndices[outerPos].size ();
689
+ newReassocIndices.push_back (llvm::to_vector (llvm::seq (start, end)));
690
+ currPos = end;
691
+ }
692
+ for (auto unused : innerTileSizes) {
693
+ (void )unused;
694
+ newReassocIndices.push_back ({currPos});
695
+ currPos += 1 ;
696
+ }
697
+
698
+ RankedTensorType newExpandType = tensor::PackOp::inferPackedType (
699
+ expandOp.getType (), innerTileSizes, baseDimsPos, newOuterDimsPerm);
700
+ auto newExpandOp = rewriter.create <tensor::ExpandShapeOp>(
701
+ expandOp.getLoc (), newExpandType, unPackOp.getSource (),
702
+ newReassocIndices);
703
+
704
+ auto emptyOp = tensor::UnPackOp::createDestinationTensor (
705
+ rewriter, unPackOp.getLoc (), newExpandOp, unPackOp.getMixedTiles (), baseDimsPos,
706
+ newOuterDimsPerm);
707
+ auto newUnPackOp = rewriter.create <tensor::UnPackOp>(
708
+ unPackOp.getLoc (), newExpandOp.getResult (), emptyOp, baseDimsPos,
709
+ unPackOp.getMixedTiles (), newOuterDimsPerm);
710
+ rewriter.replaceOp (expandOp, newUnPackOp);
711
+
712
+ return success ();
713
+ }
714
+
715
+ class PushDownUnPackOpThroughReshapeOp final
716
+ : public OpRewritePattern<tensor::UnPackOp> {
717
+ public:
718
+ PushDownUnPackOpThroughReshapeOp (MLIRContext *context,
719
+ ControlPropagationFn fun)
720
+ : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
721
+ }
722
+
723
+ LogicalResult matchAndRewrite (tensor::UnPackOp unPackOp,
724
+ PatternRewriter &rewriter) const override {
725
+ Value result = unPackOp.getResult ();
726
+ if (!result.hasOneUse ()) {
727
+ return failure ();
728
+ }
729
+ Operation *userOp = *result.user_begin ();
730
+
731
+ if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
732
+ return pushDownUnPackOpThroughExpandShape (unPackOp, expandOp, rewriter);
733
+ }
734
+ return failure ();
735
+ }
736
+
737
+ private:
738
+ ControlPropagationFn controlFn;
739
+ };
740
+
555
741
// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
@@ -774,6 +960,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
774
960
const ControlPropagationFn &controlPackUnPackPropagation) {
775
961
patterns
776
962
.insert <BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
777
- PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
963
+ BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
964
+ PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
778
965
patterns.getContext (), controlPackUnPackPropagation);
779
966
}
0 commit comments