Skip to content

Commit a0b00a0

Browse files
author
Jerry Wu
committed
Handle unit dim
1 parent 61f1d84 commit a0b00a0

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
552552
ControlPropagationFn controlFn;
553553
};
554554

555+
static SmallVector<int64_t>
556+
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
557+
ArrayRef<ReassociationIndices> reassocIndices,
558+
ArrayRef<int64_t> baseShape) {
559+
SmallVector<int64_t> projectedDimsPos;
560+
for (auto pos : dimsPos) {
561+
int64_t projectedPos = -1;
562+
for (auto it = reassocIndices[pos].rbegin();
563+
it != reassocIndices[pos].rend(); ++it) {
564+
projectedPos = *it;
565+
if (baseShape[projectedPos] > 1) {
566+
break;
567+
}
568+
}
569+
assert(projectedPos != -1 && "projected dim not found");
570+
projectedDimsPos.push_back(projectedPos);
571+
}
572+
return projectedDimsPos;
573+
}
574+
555575
static LogicalResult
556576
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
557577
tensor::PackOp packOp,
@@ -568,10 +588,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
568588
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
569589
SmallVector<ReassociationIndices> reassocIndices =
570590
collapseOp.getReassociationIndices();
571-
SmallVector<int64_t> baseDimsPos;
572-
for (auto pos : innerDimsPos) {
573-
baseDimsPos.push_back(reassocIndices[pos].back());
574-
}
591+
SmallVector<int64_t> baseDimsPos =
592+
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
593+
575594
// Check if the base dims before reassociation are divisible by the inner tile
576595
// sizes.
577596
for (auto [basePos, tileSize] :
@@ -590,11 +609,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
590609
}
591610

592611
auto emptyOp = tensor::PackOp::createDestinationTensor(
593-
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), baseDimsPos,
594-
newOuterDimsPerm);
612+
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
613+
baseDimsPos, newOuterDimsPerm);
595614
auto newPackOp = rewriter.create<tensor::PackOp>(
596-
packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos, packOp.getMixedTiles(),
597-
packOp.getPaddingValue(), newOuterDimsPerm);
615+
packOp.getLoc(), collapseOp.getSrc(), emptyOp, baseDimsPos,
616+
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
598617

599618
SmallVector<ReassociationIndices> newReassocIndices;
600619
int64_t currPos = 0;
@@ -660,10 +679,9 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
660679
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
661680
SmallVector<ReassociationIndices> reassocIndices =
662681
expandOp.getReassociationIndices();
663-
SmallVector<int64_t> baseDimsPos;
664-
for (auto pos : innerDimsPos) {
665-
baseDimsPos.push_back(reassocIndices[pos].back());
666-
}
682+
SmallVector<int64_t> baseDimsPos =
683+
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
684+
667685
// Check if the base dims after reassociation are divisible by the inner tile
668686
// sizes.
669687
for (auto [basePos, tileSize] :
@@ -702,8 +720,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
702720
newReassocIndices);
703721

704722
auto emptyOp = tensor::UnPackOp::createDestinationTensor(
705-
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), baseDimsPos,
706-
newOuterDimsPerm);
723+
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
724+
baseDimsPos, newOuterDimsPerm);
707725
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
708726
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, baseDimsPos,
709727
unPackOp.getMixedTiles(), newOuterDimsPerm);

0 commit comments

Comments
 (0)