Skip to content

Commit 143243f

Browse files
author
Jerry Wu
committed
Fix control function
1 parent c60506c commit 143243f

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -703,10 +703,6 @@ class BubbleUpPackOpThroughReshapeOp final
703703

704704
LogicalResult matchAndRewrite(tensor::PackOp packOp,
705705
PatternRewriter &rewriter) const override {
706-
// User controlled propagation function.
707-
if (!controlFn(packOp))
708-
return failure();
709-
710706
Operation *srcOp = packOp.getSource().getDefiningOp();
711707
// Currently only support when the pack op is the only user.
712708
if (!srcOp || !(srcOp->getNumResults() == 1) ||
@@ -720,6 +716,10 @@ class BubbleUpPackOpThroughReshapeOp final
720716
return failure();
721717
}
722718

719+
// User controlled propagation function.
720+
if (!controlFn(srcOp))
721+
return failure();
722+
723723
return TypeSwitch<Operation *, LogicalResult>(srcOp)
724724
.Case([&](tensor::CollapseShapeOp op) {
725725
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
@@ -825,10 +825,6 @@ class PushDownUnPackOpThroughReshapeOp final
825825

826826
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
827827
PatternRewriter &rewriter) const override {
828-
// User controlled propagation function.
829-
if (!controlFn(unPackOp))
830-
return failure();
831-
832828
Value result = unPackOp.getResult();
833829
// Currently only support unpack op with the single user.
834830
if (!result.hasOneUse()) {
@@ -841,8 +837,12 @@ class PushDownUnPackOpThroughReshapeOp final
841837
return failure();
842838
}
843839

844-
Operation *userOp = *result.user_begin();
845-
return TypeSwitch<Operation *, LogicalResult>(userOp)
840+
Operation *consumerOp = *result.user_begin();
841+
// User controlled propagation function.
842+
if (!controlFn(consumerOp))
843+
return failure();
844+
845+
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
846846
.Case([&](tensor::ExpandShapeOp op) {
847847
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
848848
})

0 commit comments

Comments
 (0)