@@ -703,10 +703,6 @@ class BubbleUpPackOpThroughReshapeOp final
703
703
704
704
LogicalResult matchAndRewrite (tensor::PackOp packOp,
705
705
PatternRewriter &rewriter) const override {
706
- // User controlled propagation function.
707
- if (!controlFn (packOp))
708
- return failure ();
709
-
710
706
Operation *srcOp = packOp.getSource ().getDefiningOp ();
711
707
// Currently only support when the pack op is the only user.
712
708
if (!srcOp || !(srcOp->getNumResults () == 1 ) ||
@@ -720,6 +716,10 @@ class BubbleUpPackOpThroughReshapeOp final
720
716
return failure ();
721
717
}
722
718
719
+ // User controlled propagation function.
720
+ if (!controlFn (srcOp))
721
+ return failure ();
722
+
723
723
return TypeSwitch<Operation *, LogicalResult>(srcOp)
724
724
.Case ([&](tensor::CollapseShapeOp op) {
725
725
return bubbleUpPackOpThroughCollapseShape (op, packOp, rewriter);
@@ -825,10 +825,6 @@ class PushDownUnPackOpThroughReshapeOp final
825
825
826
826
LogicalResult matchAndRewrite (tensor::UnPackOp unPackOp,
827
827
PatternRewriter &rewriter) const override {
828
- // User controlled propagation function.
829
- if (!controlFn (unPackOp))
830
- return failure ();
831
-
832
828
Value result = unPackOp.getResult ();
833
829
// Currently only support unpack op with the single user.
834
830
if (!result.hasOneUse ()) {
@@ -841,8 +837,12 @@ class PushDownUnPackOpThroughReshapeOp final
841
837
return failure ();
842
838
}
843
839
844
- Operation *userOp = *result.user_begin ();
845
- return TypeSwitch<Operation *, LogicalResult>(userOp)
840
+ Operation *consumerOp = *result.user_begin ();
841
+ // User controlled propagation function.
842
+ if (!controlFn (consumerOp))
843
+ return failure ();
844
+
845
+ return TypeSwitch<Operation *, LogicalResult>(consumerOp)
846
846
.Case ([&](tensor::ExpandShapeOp op) {
847
847
return pushDownUnPackOpThroughExpandShape (unPackOp, op, rewriter);
848
848
})
0 commit comments