
Commit 930a4d7

[mlir] Fix bugs in expand_shape patterns after semantics changes
1 parent: 554a2fa

2 files changed: +101 −12 lines

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 46 additions & 10 deletions
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  // Otherwise, only 1 dynamic dimension is allowed.
+  if (srcType == resultType &&
+      llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and
+  //      reassociated shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
+  if (reassociations != inverseReassociations)
+    return nullptr;
+  ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+  ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+  if (llvm::none_of(reassociations, [&](auto reInd) {
+        auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
+        auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
+        return srcSlice == resSlice &&
+               llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
 }
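The new guards exist because, after the semantics change, tensor.expand_shape carries explicit output_shape operands: a collapse followed by an expand is only guaranteed to be an identity if the expanded sizes are forced to match the originals. A hypothetical MLIR sketch of the case the fold must now reject (%h and %w are illustrative values, not taken from the commit):

    // Hypothetical, not from this commit: with two dynamic dims in one
    // group, [%h, %w] may be any factorization of the collapsed size
    // (e.g. 2x6 in, 4x3 out), so folding %1 to %arg0 would be wrong.
    %0 = tensor.collapse_shape %arg0 [[0, 1]]
        : tensor<?x?xf32> into tensor<?xf32>
    %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%h, %w]
        : tensor<?xf32> into tensor<?x?xf32>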

@@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
           composedReassociation.push_back(srcIndices);
-        else
+        } else {
           return std::nullopt;
+        }
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
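Conversely, when a reassociation group has at most one dynamic dimension, the static sizes pin the dynamic one down, so treating the pair as an identity stays sound; that is what the count_if guard above checks. A sketch of a still-foldable pair under that assumption (%d0 and %d1 are illustrative values, not from the commit):

    // Foldable: the static 4 forces %d0 == dim(%arg0, 0) and
    // %d1 == dim(%arg0, 2), so %1 is provably identical to %arg0.
    %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
        : tensor<?x4x?xf32> into tensor<?x?xf32>
    %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%d0, 4, %d1]
        : tensor<?x?xf32> into tensor<?x4x?xf32>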

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 55 additions & 2 deletions
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
 
 // -----
 
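For reference, the CHECK-NOT lines in the folding tests above assert that both reshape ops disappear after canonicalization, leaving a plain return of the argument. A sketch of the expected folded form of the fully dynamic test (SSA names as in the test; the exact printed output is an assumption):

    func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
        -> tensor<?x?xf32> {
      return %arg0 : tensor<?x?xf32>
    }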
