@@ -553,17 +553,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
553
553
ControlPropagationFn controlFn;
554
554
};
555
555
556
+ // / Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
557
+ // /
558
+ // / For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
559
+ // / targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
560
+ // / inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
561
+ // / non-unit projected dims in pos [2, 3] is 2.
562
+ // /
563
+ // / If all candidates in a reassociation are unit dims, it chooses the
564
+ // / inner-most dim pos.
556
565
static SmallVector<int64_t >
557
566
projectToInnerMostNonUnitDimsPos (ArrayRef<int64_t > dimsPos,
558
567
ArrayRef<ReassociationIndices> reassocIndices,
559
- ArrayRef<int64_t > baseShape ) {
568
+ ArrayRef<int64_t > targetShape ) {
560
569
SmallVector<int64_t > projectedDimsPos;
561
570
for (auto pos : dimsPos) {
562
571
// In the case all dims are unit, this will return the inner-most one.
563
572
int64_t projectedPos = reassocIndices[pos].back ();
564
573
for (auto it = reassocIndices[pos].rbegin ();
565
574
it != reassocIndices[pos].rend (); ++it) {
566
- int64_t dim = baseShape [*it];
575
+ int64_t dim = targetShape [*it];
567
576
if (dim > 1 || ShapedType::isDynamic (dim)) {
568
577
projectedPos = *it;
569
578
break ;
@@ -574,32 +583,36 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
574
583
return projectedDimsPos;
575
584
}
576
585
577
- static bool
578
- isProjectedDimsDivisibleByTileSizes (ArrayRef<int64_t > projectedDimsPos,
579
- ArrayRef<int64_t > targetShape,
580
- ArrayRef<int64_t > tileSizes) {
581
- for (auto [projectedPos, tileSize] :
582
- llvm::zip_equal (projectedDimsPos, tileSizes)) {
583
- int64_t dim = targetShape[projectedPos];
586
+ // / Check if all dims in dimsPos are divisible by the corresponding tile sizes.
587
+ static bool isDimsDivisibleByTileSizes (ArrayRef<int64_t > dimsPos,
588
+ ArrayRef<int64_t > shape,
589
+ ArrayRef<int64_t > tileSizes) {
590
+ for (auto [pos, tileSize] : llvm::zip_equal (dimsPos, tileSizes)) {
591
+ int64_t dim = shape[pos];
584
592
if (ShapedType::isDynamic (dim) || (dim % tileSize) != 0 )
585
593
return false ;
586
594
}
587
595
return true ;
588
596
}
589
597
598
+ // / Permutate the reassociation indices and reindex them in the sequence order.
599
+ // / Returns the next dim pos in the sequence.
600
+ // /
601
+ // / For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
602
+ // / applies the permutation to get [[2], [0, 1]] and reindexes the indices into
603
+ // / [[0], [1, 2]].
590
604
static int64_t applyPermutationAndReindexReassoc (
591
- SmallVector<ReassociationIndices> &reassociationIndices,
592
- ArrayRef<int64_t > dimsPerm) {
593
- applyPermutationToVector<ReassociationIndices>(reassociationIndices,
594
- dimsPerm);
595
- int64_t lastPos = 0 ;
596
- for (ReassociationIndices &indices : reassociationIndices) {
605
+ SmallVector<ReassociationIndices> &reassocIndices,
606
+ ArrayRef<int64_t > permutation) {
607
+ applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
608
+ int64_t nextPos = 0 ;
609
+ for (ReassociationIndices &indices : reassocIndices) {
597
610
for (auto &index : indices) {
598
- index = lastPos ;
599
- lastPos += 1 ;
611
+ index = nextPos ;
612
+ nextPos += 1 ;
600
613
}
601
614
}
602
- return lastPos ;
615
+ return nextPos ;
603
616
}
604
617
605
618
// / Bubble up pack op through collapse shape op when the packed dims can be
@@ -614,7 +627,7 @@ static int64_t applyPermutationAndReindexReassoc(
614
627
// / inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
615
628
// / : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
616
629
// /
617
- // / Can be transformed into:
630
+ // / can be transformed into:
618
631
// /
619
632
// / %pack = tensor.pack %in outer_dims_perm = [1, 2]
620
633
// / inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
@@ -632,11 +645,18 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
632
645
ArrayRef<int64_t > srcShape = collapseOp.getSrcType ().getShape ();
633
646
SmallVector<ReassociationIndices> reassocIndices =
634
647
collapseOp.getReassociationIndices ();
648
+ // Project inner tile pos to the dim pos before collapsing. For example, if
649
+ // dims [x, y] is collapsed into [z], packing on dim z can be projected back
650
+ // to pack on dim y.
651
+ //
652
+ // Project to inner-most non-unit dims to increase the chance that they can be
653
+ // divided by the inner tile sizes. This is correct because for [..., x, 1],
654
+ // packing on dim 1 is equivalent to packing on dim x.
635
655
SmallVector<int64_t > projectedInnerDimsPos =
636
656
projectToInnerMostNonUnitDimsPos (innerDimsPos, reassocIndices, srcShape);
637
657
638
- if (!isProjectedDimsDivisibleByTileSizes (projectedInnerDimsPos, srcShape,
639
- innerTileSizes)) {
658
+ if (!isDimsDivisibleByTileSizes (projectedInnerDimsPos, srcShape,
659
+ innerTileSizes)) {
640
660
return failure ();
641
661
}
642
662
// Expand the outer dims permutation with the associated source dims for the
@@ -658,14 +678,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
658
678
659
679
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
660
680
// First apply the permutation on the reassociations of the outer dims.
661
- // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
681
+ // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
662
682
// -> [[0], [1, 2]]
663
- int64_t lastPos =
683
+ int64_t nextPos =
664
684
applyPermutationAndReindexReassoc (newReassocIndices, outerDimsPerm);
665
685
// Then add direct mapping for the inner tile dims.
666
686
for (size_t i = 0 ; i < innerDimsPos.size (); ++i) {
667
- newReassocIndices.push_back ({lastPos });
668
- lastPos += 1 ;
687
+ newReassocIndices.push_back ({nextPos });
688
+ nextPos += 1 ;
669
689
}
670
690
671
691
auto newCollapseOp = rewriter.create <tensor::CollapseShapeOp>(
@@ -723,7 +743,7 @@ class BubbleUpPackOpThroughReshapeOp final
723
743
// / %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
724
744
// / : tensor<?x256xf32> into tensor<?x256x256xf32>
725
745
// /
726
- // / Can be transformed into:
746
+ // / can be transformed into:
727
747
// /
728
748
// / %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
729
749
// / : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
@@ -741,11 +761,18 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
741
761
ArrayRef<int64_t > dstShape = expandOp.getType ().getShape ();
742
762
SmallVector<ReassociationIndices> reassocIndices =
743
763
expandOp.getReassociationIndices ();
764
+ // Project inner tile pos to the dim pos after expanding. For example, if dims
765
+ // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
766
+ // on dim y.
767
+ //
768
+ // Project to inner-most non-unit dims to increase the chance that they can be
769
+ // divided by the inner tile sizes. This is correct because for [..., x, 1],
770
+ // unpacking on dim 1 is equivalent to unpacking on dim x.
744
771
SmallVector<int64_t > projectedInnerDimsPos =
745
772
projectToInnerMostNonUnitDimsPos (innerDimsPos, reassocIndices, dstShape);
746
773
747
- if (!isProjectedDimsDivisibleByTileSizes (projectedInnerDimsPos, dstShape,
748
- innerTileSizes)) {
774
+ if (!isDimsDivisibleByTileSizes (projectedInnerDimsPos, dstShape,
775
+ innerTileSizes)) {
749
776
return failure ();
750
777
}
751
778
// Expand the outer dims permutation with the associated expanded dims for the
@@ -760,14 +787,14 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
760
787
761
788
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
762
789
// First apply the permutation on the reassociations of the outer dims.
763
- // For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
790
+ // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
764
791
// -> [[0], [1, 2]]
765
- int64_t lastPos =
792
+ int64_t nextPos =
766
793
applyPermutationAndReindexReassoc (newReassocIndices, outerDimsPerm);
767
794
// Then add direct mapping for the inner tile dims.
768
795
for (size_t i = 0 ; i < innerDimsPos.size (); ++i) {
769
- newReassocIndices.push_back ({lastPos });
770
- lastPos += 1 ;
796
+ newReassocIndices.push_back ({nextPos });
797
+ nextPos += 1 ;
771
798
}
772
799
773
800
RankedTensorType newExpandType =
0 commit comments