@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
1023
1023
private:
1024
1024
ControlFusionFn controlFoldingReshapes;
1025
1025
};
1026
+
1027
+ // / Pattern to bubble up a tensor.expand_shape op through a producer
1028
+ // / tensor.collapse_shape op that has non intersecting reassociations.
1029
+ struct BubbleUpExpandThroughParallelCollapse
1030
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
1031
+ using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
1032
+
1033
+ LogicalResult matchAndRewrite (tensor::ExpandShapeOp expandOp,
1034
+ PatternRewriter &rewriter) const override {
1035
+ auto collapseOp =
1036
+ expandOp.getSrc ().getDefiningOp <tensor::CollapseShapeOp>();
1037
+ if (!collapseOp || !collapseOp->hasOneUse ())
1038
+ return failure ();
1039
+ auto expandReInds = expandOp.getReassociationIndices ();
1040
+ auto collapseReInds = collapseOp.getReassociationIndices ();
1041
+
1042
+ // Reshapes are parallel to each other if none of the reassociation indices
1043
+ // have greater than 1 index for both reshapes.
1044
+ for (auto [expandReassociation, collapseReassociation] :
1045
+ llvm::zip_equal (expandReInds, collapseReInds)) {
1046
+ if (collapseReassociation.size () != 1 && expandReassociation.size () != 1 )
1047
+ return failure ();
1048
+ }
1049
+
1050
+ // Compute new reassociation indices and expanded/collaped shapes.
1051
+ SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
1052
+ Location loc = expandOp->getLoc ();
1053
+ SmallVector<OpFoldResult> collapseSizes =
1054
+ tensor::getMixedSizes (rewriter, loc, collapseOp.getSrc ());
1055
+ SmallVector<OpFoldResult> expandSizes (getMixedValues (
1056
+ expandOp.getStaticOutputShape (), expandOp.getOutputShape (), rewriter));
1057
+ SmallVector<OpFoldResult> newExpandSizes;
1058
+ int64_t index = 0 , expandIndex = 0 , collapseIndex = 0 ;
1059
+ for (auto [idx, collapseReassociation] : llvm::enumerate (collapseReInds)) {
1060
+ if (collapseReassociation.size () != 1 ) {
1061
+ ReassociationIndices newCollapseReassociation;
1062
+ for (size_t i = 0 ; i < collapseReassociation.size (); ++i) {
1063
+ newCollapseReassociation.push_back (index);
1064
+ newExpandReInds.push_back ({index++});
1065
+ newExpandSizes.push_back (collapseSizes[collapseIndex++]);
1066
+ }
1067
+ newCollapseReInds.push_back (newCollapseReassociation);
1068
+ expandIndex++;
1069
+ continue ;
1070
+ }
1071
+ ReassociationIndices newExpandReassociation;
1072
+ auto expandReassociation = expandReInds[idx];
1073
+ for (size_t i = 0 ; i < expandReassociation.size (); ++i) {
1074
+ newExpandReassociation.push_back (index);
1075
+ newCollapseReInds.push_back ({index++});
1076
+ newExpandSizes.push_back (expandSizes[expandIndex++]);
1077
+ }
1078
+ newExpandReInds.push_back (newExpandReassociation);
1079
+ collapseIndex++;
1080
+ }
1081
+
1082
+ // Swap reshape order.
1083
+ SmallVector<Value> dynamicSizes;
1084
+ SmallVector<int64_t > staticSizes;
1085
+ dispatchIndexOpFoldResults (newExpandSizes, dynamicSizes, staticSizes);
1086
+ auto expandResultType = expandOp.getResultType ().clone (staticSizes);
1087
+ auto newExpand = rewriter.create <tensor::ExpandShapeOp>(
1088
+ loc, expandResultType, collapseOp.getSrc (), newExpandReInds,
1089
+ newExpandSizes);
1090
+ rewriter.replaceOpWithNewOp <tensor::CollapseShapeOp>(
1091
+ expandOp, newExpand.getResult (), newCollapseReInds);
1092
+ return success ();
1093
+ }
1094
+ };
1095
+
1026
1096
} // namespace
1027
1097
1028
1098
// ===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
1939
2009
controlFoldingReshapes);
1940
2010
patterns.add <FoldWithProducerReshapeOpByExpansion>(patterns.getContext (),
1941
2011
controlFoldingReshapes);
2012
+ patterns.add <BubbleUpExpandThroughParallelCollapse>(patterns.getContext ());
1942
2013
}
1943
2014
1944
2015
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns (
0 commit comments