|
17 | 17 | #include "mlir/Dialect/Utils/IndexingUtils.h"
|
18 | 18 | #include "mlir/IR/Dominance.h"
|
19 | 19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
| 20 | +#include "llvm/ADT/SetOperations.h" |
| 21 | +#include "llvm/ADT/SetVector.h" |
20 | 22 | #include "llvm/ADT/TypeSwitch.h"
|
21 | 23 | #include "llvm/Support/Debug.h"
|
22 | 24 | #include <optional>
|
@@ -694,6 +696,131 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
|
694 | 696 | return success();
|
695 | 697 | }
|
696 | 698 |
|
| 699 | +/// Project dimsPos to their collapsed positions in the reassocIndices. |
| 700 | +/// |
| 701 | +/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices |
| 702 | +/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, |
| 703 | +/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos |
| 704 | +/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. |
| 705 | +static SmallVector<int64_t> |
| 706 | +projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, |
| 707 | + ArrayRef<ReassociationIndices> reassocIndices) { |
| 708 | + SmallVector<int64_t> projectedPos; |
| 709 | + |
| 710 | + // Map each dimension to the position of corresponding reassociation index. |
| 711 | + for (auto pos : dimsPos) { |
| 712 | + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { |
| 713 | + // If the dimension is present in the current indices group, the group |
| 714 | + // position within the reassociation map is the desired projected |
| 715 | + // dimension position. |
| 716 | + if (llvm::any_of(indices, |
| 717 | + [&](int64_t expandDim) { return expandDim == pos; })) { |
| 718 | + projectedPos.push_back(idx); |
| 719 | + break; |
| 720 | + } |
| 721 | + } |
| 722 | + } |
| 723 | + assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection"); |
| 724 | + |
| 725 | + return projectedPos; |
| 726 | +} |
| 727 | + |
| 728 | +/// Bubble up pack op through expand shape op. |
| 729 | +/// |
| 730 | +/// For example: |
| 731 | +/// |
| 732 | +/// %expand = tensor.expand_shape %in [[0], [1, 2]] |
| 733 | +/// : tensor<?x64xf32> into tensor<?x4x16xf32> |
| 734 | +/// %pack = tensor.pack %expand outer_dims_perm = [0, 1] |
| 735 | +/// inner_dims_pos = [2] inner_tiles = [8] into %empty |
| 736 | +/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> |
| 737 | +/// |
| 738 | +/// can be transformed into: |
| 739 | +/// |
| 740 | +/// %pack = tensor.pack %in outer_dims_perm = [1, 2] |
| 741 | +/// inner_dims_pos = [1] inner_tiles = [8] into %empty |
| 742 | +/// : tensor<?x64xf32> -> tensor<?x8x8xf32> |
| 743 | +/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] |
| 744 | +/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> |
| 745 | +static LogicalResult |
| 746 | +bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, |
| 747 | + tensor::PackOp packOp, |
| 748 | + PatternRewriter &rewriter) { |
| 749 | + // Outer dimensions permutation is not supported currently. |
| 750 | + // TODO: Handle outer_dims_perm variants. |
| 751 | + ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
| 752 | + if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { |
| 753 | + return rewriter.notifyMatchFailure(packOp, |
| 754 | + "non-identity outer dims perm NYI"); |
| 755 | + } |
| 756 | + |
| 757 | + // Validate dimensions' relations between shape expansion and packing. |
| 758 | + SmallVector<ReassociationIndices, 4> reassoc = |
| 759 | + expandOp.getReassociationIndices(); |
| 760 | + ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); |
| 761 | + llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(), |
| 762 | + packInnerDims.end()); |
| 763 | + |
| 764 | + for (auto [idx, indices] : llvm::enumerate(reassoc)) { |
| 765 | + // For each expand_shape reassociation, figure out which dimensions get |
| 766 | + // packed if any. |
| 767 | + llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end()); |
| 768 | + llvm::SetVector<int64_t> packedDims = |
| 769 | + llvm::set_intersection(packDimsPos, expandDimPos); |
| 770 | + |
| 771 | + // The expanded dimension is not packed so, it does not affect moving pack |
| 772 | + // before shape expansion - simply continue. |
| 773 | + if (packedDims.empty()) |
| 774 | + continue; |
| 775 | + // Shape expansion cannot be propagated when multiple expanded dimension are |
| 776 | + // packed - in this case operation reordering would affect final element |
| 777 | + // positions and/or shapes can no longer be projected. |
| 778 | + if (packedDims.size() != 1) |
| 779 | + return rewriter.notifyMatchFailure( |
| 780 | + packOp, "only one of the expanded dimensions can be packed"); |
| 781 | + // Only the inner-most expanded dimension should be packed. Otherwise, |
| 782 | + // elements order will be affected after operation reordering. |
| 783 | + if (packedDims.front() != indices.back()) |
| 784 | + return rewriter.notifyMatchFailure( |
| 785 | + packOp, "can only pack the inner-most expanded dimension"); |
| 786 | + } |
| 787 | + |
| 788 | + // Project pack.inner_dims_pos to positions before shape expansion. |
| 789 | + SmallVector<int64_t> projectedInnerDimsPos = |
| 790 | + projectDimsPosIntoReassocPos(packInnerDims, reassoc); |
| 791 | + |
| 792 | + // Project the shape expansion to new packed shape. |
| 793 | + // The pack.outer_dims_perm is restricted to identity so, the permutation can |
| 794 | + // be omitted for simplicity. |
| 795 | + // TODO: Account for outer dimensions permutation. |
| 796 | + // |
| 797 | + // If reassociation is not possible, then reordering cannot happen. |
| 798 | + // This can be caused by pack padding affecting previously expanded |
| 799 | + // dimensions or packing extending dimensions. |
| 800 | + RankedTensorType newPackType = tensor::PackOp::inferPackedType( |
| 801 | + expandOp.getSrcType(), packOp.getStaticInnerTiles(), |
| 802 | + projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 803 | + auto reassocExpand = |
| 804 | + getReassociationIndicesForReshape(newPackType, packOp.getDestType()); |
| 805 | + if (!reassocExpand) |
| 806 | + return rewriter.notifyMatchFailure( |
| 807 | + packOp, "could not reassociate dims after bubbling up"); |
| 808 | + |
| 809 | + Value destTensor = tensor::PackOp::createDestinationTensor( |
| 810 | + rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), |
| 811 | + projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 812 | + Value packedVal = rewriter.create<tensor::PackOp>( |
| 813 | + packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, |
| 814 | + packOp.getMixedTiles(), packOp.getPaddingValue(), |
| 815 | + /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 816 | + |
| 817 | + Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
| 818 | + packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); |
| 819 | + rewriter.replaceOp(packOp, newExpandOp); |
| 820 | + |
| 821 | + return success(); |
| 822 | +} |
| 823 | + |
697 | 824 | class BubbleUpPackOpThroughReshapeOp final
|
698 | 825 | : public OpRewritePattern<tensor::PackOp> {
|
699 | 826 | public:
|
@@ -723,6 +850,9 @@ class BubbleUpPackOpThroughReshapeOp final
|
723 | 850 | .Case([&](tensor::CollapseShapeOp op) {
|
724 | 851 | return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
|
725 | 852 | })
|
| 853 | + .Case([&](tensor::ExpandShapeOp op) { |
| 854 | + return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); |
| 855 | + }) |
726 | 856 | .Default([](Operation *) { return failure(); });
|
727 | 857 | }
|
728 | 858 |
|
|
0 commit comments