@@ -552,6 +552,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
552
552
ControlPropagationFn controlFn;
553
553
};
554
554
555
+ static SmallVector<int64_t >
556
+ projectToInnerMostNonUnitDimsPos (ArrayRef<int64_t > dimsPos,
557
+ ArrayRef<ReassociationIndices> reassocIndices,
558
+ ArrayRef<int64_t > baseShape) {
559
+ SmallVector<int64_t > projectedDimsPos;
560
+ for (auto pos : dimsPos) {
561
+ int64_t projectedPos = -1 ;
562
+ for (auto it = reassocIndices[pos].rbegin ();
563
+ it != reassocIndices[pos].rend (); ++it) {
564
+ projectedPos = *it;
565
+ if (baseShape[projectedPos] > 1 ) {
566
+ break ;
567
+ }
568
+ }
569
+ assert (projectedPos != -1 && " projected dim not found" );
570
+ projectedDimsPos.push_back (projectedPos);
571
+ }
572
+ return projectedDimsPos;
573
+ }
574
+
555
575
static LogicalResult
556
576
bubbleUpPackOpThroughCollapseShape (tensor::CollapseShapeOp collapseOp,
557
577
tensor::PackOp packOp,
@@ -568,10 +588,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
568
588
ArrayRef<int64_t > srcShape = collapseOp.getSrcType ().getShape ();
569
589
SmallVector<ReassociationIndices> reassocIndices =
570
590
collapseOp.getReassociationIndices ();
571
- SmallVector<int64_t > baseDimsPos;
572
- for (auto pos : innerDimsPos) {
573
- baseDimsPos.push_back (reassocIndices[pos].back ());
574
- }
591
+ SmallVector<int64_t > baseDimsPos =
592
+ projectToInnerMostNonUnitDimsPos (innerDimsPos, reassocIndices, srcShape);
593
+
575
594
// Check if the base dims before reassociation are divisible by the inner tile
576
595
// sizes.
577
596
for (auto [basePos, tileSize] :
@@ -590,11 +609,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
590
609
}
591
610
592
611
auto emptyOp = tensor::PackOp::createDestinationTensor (
593
- rewriter, packOp.getLoc (), collapseOp.getSrc (), packOp.getMixedTiles (), baseDimsPos,
594
- newOuterDimsPerm);
612
+ rewriter, packOp.getLoc (), collapseOp.getSrc (), packOp.getMixedTiles (),
613
+ baseDimsPos, newOuterDimsPerm);
595
614
auto newPackOp = rewriter.create <tensor::PackOp>(
596
- packOp.getLoc (), collapseOp.getSrc (), emptyOp, baseDimsPos, packOp. getMixedTiles (),
597
- packOp.getPaddingValue (), newOuterDimsPerm);
615
+ packOp.getLoc (), collapseOp.getSrc (), emptyOp, baseDimsPos,
616
+ packOp.getMixedTiles (), packOp. getPaddingValue (), newOuterDimsPerm);
598
617
599
618
SmallVector<ReassociationIndices> newReassocIndices;
600
619
int64_t currPos = 0 ;
@@ -660,10 +679,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
660
679
ArrayRef<int64_t > dstShape = expandOp.getType ().getShape ();
661
680
SmallVector<ReassociationIndices> reassocIndices =
662
681
expandOp.getReassociationIndices ();
663
- SmallVector<int64_t > baseDimsPos;
664
- for (auto pos : innerDimsPos) {
665
- baseDimsPos.push_back (reassocIndices[pos].back ());
666
- }
682
+ SmallVector<int64_t > baseDimsPos =
683
+ projectToInnerMostNonUnitDimsPos (innerDimsPos, reassocIndices, dstShape);
684
+
667
685
// Check if the base dims after reassociation are divisible by the inner tile
668
686
// sizes.
669
687
for (auto [basePos, tileSize] :
@@ -702,8 +720,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
702
720
newReassocIndices);
703
721
704
722
auto emptyOp = tensor::UnPackOp::createDestinationTensor (
705
- rewriter, unPackOp.getLoc (), newExpandOp, unPackOp.getMixedTiles (), baseDimsPos,
706
- newOuterDimsPerm);
723
+ rewriter, unPackOp.getLoc (), newExpandOp, unPackOp.getMixedTiles (),
724
+ baseDimsPos, newOuterDimsPerm);
707
725
auto newUnPackOp = rewriter.create <tensor::UnPackOp>(
708
726
unPackOp.getLoc (), newExpandOp.getResult (), emptyOp, baseDimsPos,
709
727
unPackOp.getMixedTiles (), newOuterDimsPerm);
0 commit comments