
Commit 61f1d84

Author: Jerry Wu
Commit message: Test collapse pack and unpack expand
Parent: 7ef1a59

File tree: 1 file changed (+188, -1 lines)


mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 188 additions & 1 deletion
@@ -552,6 +552,192 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
   ControlPropagationFn controlFn;
 };

+static LogicalResult
+bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
+                                   tensor::PackOp packOp,
+                                   PatternRewriter &rewriter) {
+  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+  if (llvm::any_of(innerTileSizes,
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      collapseOp.getReassociationIndices();
+  SmallVector<int64_t> baseDimsPos;
+  for (auto pos : innerDimsPos) {
+    baseDimsPos.push_back(reassocIndices[pos].back());
+  }
+  // Check if the base dims before reassociation are divisible by the inner tile
+  // sizes.
+  for (auto [basePos, tileSize] :
+       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+    int64_t dim = srcShape[basePos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) {
+      return failure();
+    }
+  }
+  // Expand the outer dims perm with associated src dims.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  auto emptyOp = tensor::PackOp::createDestinationTensor(
+      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
+      baseDimsPos, newOuterDimsPerm);
+  auto newPackOp = rewriter.create<tensor::PackOp>(
+      packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
+      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
+
+  SmallVector<ReassociationIndices> newReassocIndices;
+  int64_t currPos = 0;
+  for (auto outerPos : outerDimsPerm) {
+    int64_t start = currPos;
+    int64_t end = start + reassocIndices[outerPos].size();
+    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+    currPos = end;
+  }
+  for (auto unused : innerTileSizes) {
+    (void)unused;
+    newReassocIndices.push_back({currPos});
+    currPos += 1;
+  }
+
+  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
+  rewriter.replaceOp(packOp, newCollapseOp);
+
+  return success();
+}
+
+class BubbleUpPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    if (packOp.getPaddingValue())
+      return failure();
+
+    Operation *srcOp = packOp.getSource().getDefiningOp();
+    if (!srcOp || !(srcOp->getNumResults() == 1) ||
+        !srcOp->getResult(0).hasOneUse())
+      return failure();
+
+    if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(srcOp)) {
+      return bubbleUpPackOpThroughCollapseShape(collapseOp, packOp, rewriter);
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
+static LogicalResult
+pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
+                                   tensor::ExpandShapeOp expandOp,
+                                   PatternRewriter &rewriter) {
+
+  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
+  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
+  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+
+  if (llvm::any_of(innerTileSizes,
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
+    return failure();
+  }
+
+  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  SmallVector<ReassociationIndices> reassocIndices =
+      expandOp.getReassociationIndices();
+  SmallVector<int64_t> baseDimsPos;
+  for (auto pos : innerDimsPos) {
+    baseDimsPos.push_back(reassocIndices[pos].back());
+  }
+  // Check if the base dims after reassociation are divisible by the inner tile
+  // sizes.
+  for (auto [basePos, tileSize] :
+       llvm::zip_equal(baseDimsPos, innerTileSizes)) {
+    int64_t dim = dstShape[basePos];
+    if (ShapedType::isDynamic(dim) || dstShape[basePos] % tileSize != 0) {
+      return failure();
+    }
+  }
+  // Expand the outer dims perm with associated src dims.
+  SmallVector<int64_t> newOuterDimsPerm;
+  for (auto outerPos : outerDimsPerm) {
+    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                            reassocIndices[outerPos].begin(),
+                            reassocIndices[outerPos].end());
+  }
+
+  SmallVector<ReassociationIndices> newReassocIndices;
+  int64_t currPos = 0;
+  for (auto outerPos : outerDimsPerm) {
+    int64_t start = currPos;
+    int64_t end = start + reassocIndices[outerPos].size();
+    newReassocIndices.push_back(llvm::to_vector(llvm::seq(start, end)));
+    currPos = end;
+  }
+  for (auto unused : innerTileSizes) {
+    (void)unused;
+    newReassocIndices.push_back({currPos});
+    currPos += 1;
+  }
+
+  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+      expandOp.getType(), innerTileSizes, baseDimsPos, newOuterDimsPerm);
+  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
+      newReassocIndices);
+
+  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
+      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
+      baseDimsPos, newOuterDimsPerm);
+  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
+      unPackOp.getMixedTiles(), newOuterDimsPerm);
+  rewriter.replaceOp(expandOp, newUnPackOp);
+
+  return success();
+}
+
+class PushDownUnPackOpThroughReshapeOp final
+    : public OpRewritePattern<tensor::UnPackOp> {
+public:
+  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
+                                   ControlPropagationFn fun)
+      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
+  }
+
+  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
+                                PatternRewriter &rewriter) const override {
+    Value result = unPackOp.getResult();
+    if (!result.hasOneUse()) {
+      return failure();
+    }
+    Operation *userOp = *result.user_begin();
+
+    if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(userOp)) {
+      return pushDownUnPackOpThroughExpandShape(unPackOp, expandOp, rewriter);
+    }
+    return failure();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
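To make the first pattern concrete, here is a hypothetical before/after IR pair for bubbleUpPackOpThroughCollapseShape (the shapes, tile sizes, and value names are invented for this sketch and do not come from the commit or its tests). The tensor.pack is bubbled up above the tensor.collapse_shape: inner_dims_pos = [0, 1] on the collapsed tensor is remapped to the last source dim of each reassociation group, giving [1, 2]; outer_dims_perm is expanded with each group's source dims; and a trailing tensor.collapse_shape restores the originally requested packed type. Before:

    %dest = tensor.empty() : tensor<4x4x8x8xf32>
    %collapsed = tensor.collapse_shape %src [[0, 1], [2]]
        : tensor<2x16x32xf32> into tensor<32x32xf32>
    %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
        inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest
        : tensor<32x32xf32> -> tensor<4x4x8x8xf32>

After:

    %new_dest = tensor.empty() : tensor<2x2x4x8x8xf32>
    %pack = tensor.pack %src outer_dims_perm = [0, 1, 2]
        inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %new_dest
        : tensor<2x16x32xf32> -> tensor<2x2x4x8x8xf32>
    %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
        : tensor<2x2x4x8x8xf32> into tensor<4x4x8x8xf32>

The divisibility guard in the function is what makes this legal here: the base source dims 16 and 32 are both multiples of the tile size 8, so no padding is introduced.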
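Likewise, a hypothetical sketch of pushDownUnPackOpThroughExpandShape (again with invented shapes and names): the tensor.unpack is pushed below the tensor.expand_shape by first expanding the packed tensor, with singleton reassociation groups appended for the tile dims, and then unpacking straight to the expanded type. Before:

    %dest = tensor.empty() : tensor<32x32xf32>
    %unpack = tensor.unpack %src outer_dims_perm = [0, 1]
        inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %dest
        : tensor<4x4x8x8xf32> -> tensor<32x32xf32>
    %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
        : tensor<32x32xf32> into tensor<2x16x32xf32>

After:

    %expanded = tensor.expand_shape %src [[0, 1], [2], [3], [4]]
        : tensor<4x4x8x8xf32> into tensor<2x2x4x8x8xf32>
    %new_dest = tensor.empty() : tensor<2x16x32xf32>
    %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
        inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %new_dest
        : tensor<2x2x4x8x8xf32> -> tensor<2x16x32xf32>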
@@ -774,6 +960,7 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
     const ControlPropagationFn &controlPackUnPackPropagation) {
   patterns
       .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
+              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
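A rough C++ usage sketch for the updated entry point (assuming the standard greedy pattern driver inside a pass; funcOp and the accept-all control callback are placeholders for this sketch, not part of the commit):

    // Hypothetical pass body: collect the data-layout propagation patterns,
    // including the two reshape patterns this commit registers, and apply
    // them greedily to a function.
    RewritePatternSet patterns(funcOp.getContext());
    linalg::populateDataLayoutPropagationPatterns(
        patterns, [](Operation *) { return true; });
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();

Note that in this commit the control function gates only the generic-op and pad patterns; the two new reshape patterns store controlFn but do not consult it yet.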
