Skip to content

Commit 3f79dc6

Browse files
author
Jerry Wu
committed
Refactor
1 parent ff6aa3a commit 3f79dc6

File tree

1 file changed

+110
-68
lines changed

1 file changed

+110
-68
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 110 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Utils/IndexingUtils.h"
1818
#include "mlir/IR/Dominance.h"
1919
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "llvm/ADT/TypeSwitch.h"
2021
#include "llvm/Support/Debug.h"
2122
#include <optional>
2223

@@ -572,6 +573,39 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
572573
return projectedDimsPos;
573574
}
574575

576+
/// Applies `dimsPerm` to `reassociationIndices` and then renumbers every
/// index so that, in the permuted group order, the flattened indices run
/// consecutively from 0. Returns one past the last assigned index, i.e. the
/// total number of (expanded) dims covered by the reassociation.
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassociationIndices,
    ArrayRef<int64_t> dimsPerm) {
  // Reorder the reassociation groups to follow the given permutation.
  applyPermutationToVector<ReassociationIndices>(reassociationIndices,
                                                 dimsPerm);
  // Renumber all indices sequentially in the new group order.
  int64_t lastPos = 0;
  for (ReassociationIndices &indices : reassociationIndices) {
    for (auto &index : indices) {
      index = lastPos;
      lastPos += 1;
    }
  }
  return lastPos;
}
590+
591+
/// Bubble up pack op through collapse shape op when the packed dims can be
592+
/// mapped to the source dims before collapsing. This is possible when the inner
593+
/// tile sizes can divide the mapped source dims.
594+
///
595+
/// For example:
596+
///
597+
/// %collapsed = tensor.collapse_shape %in [[0, 1], 2] : tensor<?x16x4xf32> into
598+
/// tensor<?x4xf32> %out = tensor.empty() : tensor<?x4x8x1xf32> %pack =
599+
/// tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
600+
/// inner_tiles = [8, 1] into %out : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
601+
///
602+
/// Can be transformed into:
603+
///
604+
/// %out = tensor.empty() : tensor<?x2x4x8x1xf32>
605+
/// %pack = tensor.pack %in outer_dims_perm = [1, 2] inner_dims_pos = [1, 2]
606+
/// inner_tiles = [8, 1] into %out : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
607+
/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4] :
608+
/// tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
575609
static LogicalResult
576610
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
577611
tensor::PackOp packOp,
@@ -580,27 +614,23 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
580614
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
581615
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
582616

583-
if (llvm::any_of(innerTileSizes,
584-
[](int64_t size) { return ShapedType::isDynamic(size); })) {
585-
return failure();
586-
}
587-
588617
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
589618
SmallVector<ReassociationIndices> reassocIndices =
590619
collapseOp.getReassociationIndices();
591-
SmallVector<int64_t> baseDimsPos =
620+
SmallVector<int64_t> projectedInnerDimsPos =
592621
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
593622

594-
// Check if the base dims before reassociation are divisible by the inner tile
623+
// Check if the projected dims on the source are divisible by the inner tile
595624
// sizes.
596-
for (auto [basePos, tileSize] :
597-
llvm::zip_equal(baseDimsPos, innerTileSizes)) {
598-
int64_t dim = srcShape[basePos];
599-
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
625+
for (auto [projectedPos, tileSize] :
626+
llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
627+
int64_t dim = srcShape[projectedPos];
628+
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
600629
return failure();
601-
}
602630
}
603-
// Expand the outer dims perm with associated src dims.
631+
// Expand the outer dims permutation with the associated source dims for the
632+
// new permutation after bubbling. This is because moving a collapsed dim is
633+
// equivalent to moving the associated source dims together.
604634
SmallVector<int64_t> newOuterDimsPerm;
605635
for (auto outerPos : outerDimsPerm) {
606636
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
@@ -610,23 +640,19 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
610640

611641
auto emptyOp = tensor::PackOp::createDestinationTensor(
612642
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
613-
baseDimsPos, newOuterDimsPerm);
643+
projectedInnerDimsPos, newOuterDimsPerm);
614644
auto newPackOp = rewriter.create<tensor::PackOp>(
615-
packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
645+
packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
616646
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
617647

618-
SmallVector<ReassociationIndices> newReassocIndices;
619-
int64_t currPos = 0;
620-
for (auto outerPos : outerDimsPerm) {
621-
int64_t start = currPos;
622-
int64_t end = start + reassocIndices[outerPos].size();
623-
newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
624-
currPos = end;
625-
}
626-
for (auto unused : innerTileSizes) {
627-
(void)unused;
628-
newReassocIndices.push_back({currPos});
629-
currPos += 1;
648+
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
649+
// First build reassociations on the outer dims after the permutation.
650+
int64_t lastPos =
651+
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
652+
// Then add direct mapping for the inner tile dims.
653+
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
654+
newReassocIndices.push_back({lastPos});
655+
lastPos += 1;
630656
}
631657

632658
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -644,18 +670,28 @@ class BubbleUpPackOpThroughReshapeOp final
644670

645671
LogicalResult matchAndRewrite(tensor::PackOp packOp,
                              PatternRewriter &rewriter) const override {
  // User controlled propagation function.
  if (!controlFn(packOp))
    return failure();

  Operation *srcOp = packOp.getSource().getDefiningOp();
  // Currently only support when the pack op is the only user.
  if (!srcOp || !(srcOp->getNumResults() == 1) ||
      !srcOp->getResult(0).hasOneUse()) {
    return failure();
  }
  // Currently only support static inner tile sizes.
  if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
        return ShapedType::isDynamic(size);
      })) {
    return failure();
  }

  // Dispatch on the kind of the producer op; only collapse_shape is
  // supported for now. Unhandled producers fail the match.
  return TypeSwitch<Operation *, LogicalResult>(srcOp)
      .Case([&](tensor::CollapseShapeOp op) {
        return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
      })
      .Default([](Operation *) { return failure(); });
}
660696

661697
private:
@@ -666,65 +702,59 @@ static LogicalResult
666702
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
667703
tensor::ExpandShapeOp expandOp,
668704
PatternRewriter &rewriter) {
669-
670705
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
671706
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
672707
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
673708

674-
if (llvm::any_of(innerTileSizes,
675-
[](int64_t size) { return ShapedType::isDynamic(size); })) {
676-
return failure();
677-
}
678-
679709
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
680710
SmallVector<ReassociationIndices> reassocIndices =
681711
expandOp.getReassociationIndices();
682-
SmallVector<int64_t> baseDimsPos =
712+
SmallVector<int64_t> projectedInnerDimsPos =
683713
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
684714

685-
// Check if the base dims after reassociation are divisible by the inner tile
715+
// Check if the projected dims on the dest are divisible by the inner tile
686716
// sizes.
687-
for (auto [basePos, tileSize] :
688-
llvm::zip_equal(baseDimsPos, innerTileSizes)) {
689-
int64_t dim = dstShape[basePos];
690-
if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
717+
for (auto [projectedPos, tileSize] :
718+
llvm::zip_equal(projectedInnerDimsPos, innerTileSizes)) {
719+
int64_t dim = dstShape[projectedPos];
720+
if (ShapedType::isDynamic(dim) ||
721+
(dstShape[projectedPos] % tileSize) != 0) {
691722
return failure();
692723
}
693724
}
694-
// Expand the outer dims perm with associated src dims.
725+
// Expand the outer dims permutation with the associated expanded dims for the
726+
// new permutation after pushing. This is because moving a source dim is
727+
// equivalent to moving the associated expanded dims together.
695728
SmallVector<int64_t> newOuterDimsPerm;
696729
for (auto outerPos : outerDimsPerm) {
697730
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
698731
reassocIndices[outerPos].begin(),
699732
reassocIndices[outerPos].end());
700733
}
701734

702-
SmallVector<ReassociationIndices> newReassocIndices;
703-
int64_t currPos = 0;
704-
for (auto outerPos : outerDimsPerm) {
705-
int64_t start = currPos;
706-
int64_t end = start + reassocIndices[outerPos].size();
707-
newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
708-
currPos = end;
709-
}
710-
for (auto unused : innerTileSizes) {
711-
(void)unused;
712-
newReassocIndices.push_back({currPos});
713-
currPos += 1;
735+
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
736+
// First build reassociations on the outer dims after the permutation.
737+
int64_t lastPos =
738+
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
739+
// Then add direct mapping for the inner tile dims.
740+
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
741+
newReassocIndices.push_back({lastPos});
742+
lastPos += 1;
714743
}
715744

716-
RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
717-
expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
745+
RankedTensorType newExpandType =
746+
tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
747+
projectedInnerDimsPos, newOuterDimsPerm);
718748
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
719749
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
720750
newReassocIndices);
721751

722752
auto emptyOp = tensor::UnPackOp::createDestinationTensor(
723753
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
724-
baseDimsPos, newOuterDimsPerm);
754+
projectedInnerDimsPos, newOuterDimsPerm);
725755
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
726-
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
727-
unPackOp.getMixedTiles(), newOuterDimsPerm);
756+
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
757+
projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
728758
rewriter.replaceOp(expandOp, newUnPackOp);
729759

730760
return success();
@@ -740,16 +770,28 @@ class PushDownUnPackOpThroughReshapeOp final
740770

741771
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                              PatternRewriter &rewriter) const override {
  // User controlled propagation function.
  if (!controlFn(unPackOp))
    return failure();

  Value result = unPackOp.getResult();
  // Currently only support unpack op with the single user.
  if (!result.hasOneUse()) {
    return failure();
  }
  // Currently only support static inner tile sizes.
  if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
        return ShapedType::isDynamic(size);
      })) {
    return failure();
  }

  // Dispatch on the kind of the single consumer op; only expand_shape is
  // supported for now. Unhandled consumers fail the match.
  Operation *userOp = *result.user_begin();
  return TypeSwitch<Operation *, LogicalResult>(userOp)
      .Case([&](tensor::ExpandShapeOp op) {
        return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
      })
      .Default([](Operation *) { return failure(); });
}
754796

755797
private:

0 commit comments

Comments
 (0)