Skip to content

Commit 8b5a6be

Browse files
committed
[mlir] Add bubbling patterns for non-intersecting reshapes
1 parent 2dc8fea commit 8b5a6be

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
10231023
private:
10241024
ControlFusionFn controlFoldingReshapes;
10251025
};
1026+
1027+
/// Pattern to bubble up a tensor.expand_shape op through a producer
1028+
/// tensor.collapse_shape op that has non intersecting reassociations.
1029+
struct BubbleUpExpandThroughParallelCollapse
1030+
: public OpRewritePattern<tensor::ExpandShapeOp> {
1031+
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
1032+
1033+
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
1034+
PatternRewriter &rewriter) const override {
1035+
auto collapseOp =
1036+
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
1037+
if (!collapseOp || !collapseOp->hasOneUse())
1038+
return failure();
1039+
auto expandReInds = expandOp.getReassociationIndices();
1040+
auto collapseReInds = collapseOp.getReassociationIndices();
1041+
1042+
// Reshapes are parallel to each other if none of the reassociation indices
1043+
// have greater than 1 index for both reshapes.
1044+
for (auto [expandReassociation, collapseReassociation] :
1045+
llvm::zip_equal(expandReInds, collapseReInds)) {
1046+
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
1047+
return failure();
1048+
}
1049+
1050+
// Compute new reassociation indices and expanded/collaped shapes.
1051+
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
1052+
Location loc = expandOp->getLoc();
1053+
SmallVector<OpFoldResult> collapseSizes =
1054+
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
1055+
SmallVector<OpFoldResult> expandSizes(getMixedValues(
1056+
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
1057+
SmallVector<OpFoldResult> newExpandSizes;
1058+
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
1059+
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
1060+
if (collapseReassociation.size() != 1) {
1061+
ReassociationIndices newCollapseReassociation;
1062+
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
1063+
newCollapseReassociation.push_back(index);
1064+
newExpandReInds.push_back({index++});
1065+
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
1066+
}
1067+
newCollapseReInds.push_back(newCollapseReassociation);
1068+
expandIndex++;
1069+
continue;
1070+
}
1071+
ReassociationIndices newExpandReassociation;
1072+
auto expandReassociation = expandReInds[idx];
1073+
for (size_t i = 0; i < expandReassociation.size(); ++i) {
1074+
newExpandReassociation.push_back(index);
1075+
newCollapseReInds.push_back({index++});
1076+
newExpandSizes.push_back(expandSizes[expandIndex++]);
1077+
}
1078+
newExpandReInds.push_back(newExpandReassociation);
1079+
collapseIndex++;
1080+
}
1081+
1082+
// Swap reshape order.
1083+
SmallVector<Value> dynamicSizes;
1084+
SmallVector<int64_t> staticSizes;
1085+
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
1086+
auto expandResultType = expandOp.getResultType().clone(staticSizes);
1087+
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
1088+
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
1089+
newExpandSizes);
1090+
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1091+
expandOp, newExpand.getResult(), newCollapseReInds);
1092+
return success();
1093+
}
1094+
};
1095+
10261096
} // namespace
10271097

10281098
//===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
19392009
controlFoldingReshapes);
19402010
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
19412011
controlFoldingReshapes);
2012+
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
19422013
}
19432014

19442015
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,37 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
826826
// CHECK-SAME: [0, 1], [2, 3]
827827
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
828828
// CHECK: return %[[T4]]
829+
830+
// -----
831+
832+
func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expanded = tensor.expand_shape %collapsed [[0], [1], [2, 3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expanded : tensor<?x?x?x?xf32>
}
// The reshapes touch disjoint dim groups, so the expand is bubbled above the
// collapse: expand first (into rank 5), then collapse.
//      CHECK: func @bubble_parallel_reshapes
// CHECK-SAME:   %[[SRC:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]] : tensor<?x?x?x?xf32>
//  CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[SRC]], %[[C2]] : tensor<?x?x?x?xf32>
//      CHECK:   %[[EXP:.+]] = tensor.expand_shape %[[SRC]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME:       output_shape [%[[SZ0]], %[[D1]], %[[D2]], %[[SZ2]], %[[SZ3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
//      CHECK:   %[[COL:.+]] = tensor.collapse_shape %[[EXP]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
//      CHECK:   return %[[COL]]
849+
850+
// -----
851+
852+
func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapsed = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expanded : tensor<?x?x?x?xf32>
}
// Both reshapes group dim 1 of the intermediate tensor, so they intersect and
// the pattern must not fire: the collapse/expand order is left unchanged.
//      CHECK: func @no_bubble_intersecting_reshapes
// CHECK-SAME:   %[[SRC:.+]]: tensor<?x?x?x?xf32>
//      CHECK:   %[[COL:.+]] = tensor.collapse_shape %[[SRC]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   %[[EXP:.+]] = tensor.expand_shape %[[COL]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   return %[[EXP]]

0 commit comments

Comments
 (0)