Skip to content

Commit a945f55

Browse files
adam-smnkpashu123
andauthored
[mlir][linalg] Add pattern to bubble-up pack through expand shape op (#93529)
Extends bubble-up pack through reshape pattern to handle pack propagation through expand shape ops. --------- Co-authored-by: Prashant Kumar <[email protected]>
1 parent 0e21f12 commit a945f55

File tree

2 files changed

+390
-0
lines changed

2 files changed

+390
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include "mlir/Dialect/Utils/IndexingUtils.h"
1818
#include "mlir/IR/Dominance.h"
1919
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "llvm/ADT/SetOperations.h"
21+
#include "llvm/ADT/SetVector.h"
2022
#include "llvm/ADT/TypeSwitch.h"
2123
#include "llvm/Support/Debug.h"
2224
#include <optional>
@@ -694,6 +696,131 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
694696
return success();
695697
}
696698

699+
/// Project dimsPos to their collapsed positions in the reassocIndices.
700+
///
701+
/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
702+
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
703+
/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
704+
/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
705+
static SmallVector<int64_t>
706+
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
707+
ArrayRef<ReassociationIndices> reassocIndices) {
708+
SmallVector<int64_t> projectedPos;
709+
710+
// Map each dimension to the position of corresponding reassociation index.
711+
for (auto pos : dimsPos) {
712+
for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
713+
// If the dimension is present in the current indices group, the group
714+
// position within the reassociation map is the desired projected
715+
// dimension position.
716+
if (llvm::any_of(indices,
717+
[&](int64_t expandDim) { return expandDim == pos; })) {
718+
projectedPos.push_back(idx);
719+
break;
720+
}
721+
}
722+
}
723+
assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
724+
725+
return projectedPos;
726+
}
727+
728+
/// Bubble up pack op through expand shape op.
729+
///
730+
/// For example:
731+
///
732+
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
733+
/// : tensor<?x64xf32> into tensor<?x4x16xf32>
734+
/// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
735+
/// inner_dims_pos = [2] inner_tiles = [8] into %empty
736+
/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
737+
///
738+
/// can be transformed into:
739+
///
740+
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
741+
/// inner_dims_pos = [1] inner_tiles = [8] into %empty
742+
/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
743+
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
744+
/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
745+
static LogicalResult
746+
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
747+
tensor::PackOp packOp,
748+
PatternRewriter &rewriter) {
749+
// Outer dimensions permutation is not supported currently.
750+
// TODO: Handle outer_dims_perm variants.
751+
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
752+
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
753+
return rewriter.notifyMatchFailure(packOp,
754+
"non-identity outer dims perm NYI");
755+
}
756+
757+
// Validate dimensions' relations between shape expansion and packing.
758+
SmallVector<ReassociationIndices, 4> reassoc =
759+
expandOp.getReassociationIndices();
760+
ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
761+
llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
762+
packInnerDims.end());
763+
764+
for (auto [idx, indices] : llvm::enumerate(reassoc)) {
765+
// For each expand_shape reassociation, figure out which dimensions get
766+
// packed if any.
767+
llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
768+
llvm::SetVector<int64_t> packedDims =
769+
llvm::set_intersection(packDimsPos, expandDimPos);
770+
771+
// The expanded dimension is not packed so, it does not affect moving pack
772+
// before shape expansion - simply continue.
773+
if (packedDims.empty())
774+
continue;
775+
// Shape expansion cannot be propagated when multiple expanded dimension are
776+
// packed - in this case operation reordering would affect final element
777+
// positions and/or shapes can no longer be projected.
778+
if (packedDims.size() != 1)
779+
return rewriter.notifyMatchFailure(
780+
packOp, "only one of the expanded dimensions can be packed");
781+
// Only the inner-most expanded dimension should be packed. Otherwise,
782+
// elements order will be affected after operation reordering.
783+
if (packedDims.front() != indices.back())
784+
return rewriter.notifyMatchFailure(
785+
packOp, "can only pack the inner-most expanded dimension");
786+
}
787+
788+
// Project pack.inner_dims_pos to positions before shape expansion.
789+
SmallVector<int64_t> projectedInnerDimsPos =
790+
projectDimsPosIntoReassocPos(packInnerDims, reassoc);
791+
792+
// Project the shape expansion to new packed shape.
793+
// The pack.outer_dims_perm is restricted to identity so, the permutation can
794+
// be omitted for simplicity.
795+
// TODO: Account for outer dimensions permutation.
796+
//
797+
// If reassociation is not possible, then reordering cannot happen.
798+
// This can be caused by pack padding affecting previously expanded
799+
// dimensions or packing extending dimensions.
800+
RankedTensorType newPackType = tensor::PackOp::inferPackedType(
801+
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
802+
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
803+
auto reassocExpand =
804+
getReassociationIndicesForReshape(newPackType, packOp.getDestType());
805+
if (!reassocExpand)
806+
return rewriter.notifyMatchFailure(
807+
packOp, "could not reassociate dims after bubbling up");
808+
809+
Value destTensor = tensor::PackOp::createDestinationTensor(
810+
rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
811+
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
812+
Value packedVal = rewriter.create<tensor::PackOp>(
813+
packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
814+
packOp.getMixedTiles(), packOp.getPaddingValue(),
815+
/*outerDimsPerm=*/SmallVector<int64_t>{});
816+
817+
Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
818+
packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
819+
rewriter.replaceOp(packOp, newExpandOp);
820+
821+
return success();
822+
}
823+
697824
class BubbleUpPackOpThroughReshapeOp final
698825
: public OpRewritePattern<tensor::PackOp> {
699826
public:
@@ -723,6 +850,9 @@ class BubbleUpPackOpThroughReshapeOp final
723850
.Case([&](tensor::CollapseShapeOp op) {
724851
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
725852
})
853+
.Case([&](tensor::ExpandShapeOp op) {
854+
return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
855+
})
726856
.Default([](Operation *) { return failure(); });
727857
}
728858

0 commit comments

Comments
 (0)