@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
85
85
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
86
86
static OpFoldResult foldReshapeOp (ReshapeOpTy reshapeOp,
87
87
ArrayRef<Attribute> operands) {
88
-
88
+ // Fold identity reshape.
89
89
if (reshapeOp.getSrcType () == reshapeOp.getType ())
90
90
return reshapeOp.getSrc ();
91
91
92
- // Fold producer-consumer reshape ops where the operand type of the
93
- // producer is same as the return type of the consumer.
94
- auto reshapeSrcOp =
95
- reshapeOp.getSrc ().template getDefiningOp <InverseReshapeOpTy>();
96
- if (reshapeSrcOp && reshapeSrcOp.getSrcType () == reshapeOp.getResultType ())
97
- return reshapeSrcOp.getSrc ();
98
-
99
92
// Reshape of a constant can be replaced with a new constant.
100
93
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front ()))
101
94
return elements.reshape (cast<ShapedType>(reshapeOp.getResult ().getType ()));
102
95
96
+ // Fold if the producer reshape source has the same shape with at most 1
97
+ // dynamic dimension.
98
+ auto reshapeSrcOp =
99
+ reshapeOp.getSrc ().template getDefiningOp <InverseReshapeOpTy>();
100
+ if (!reshapeSrcOp)
101
+ return nullptr ;
102
+ auto srcType = reshapeSrcOp.getSrcType ();
103
+ auto resultType = reshapeOp.getResultType ();
104
+ if (srcType != resultType)
105
+ return nullptr ;
106
+
107
+ // If the reshapes are expanding and then collapsing, the ops can be folded
108
+ // despite multiple dynamic dimensions.
109
+ if (srcType.getRank () < reshapeSrcOp.getResultType ().getRank ())
110
+ return reshapeSrcOp.getSrc ();
111
+ // Otherwise, only 1 dynamic dimension is allowed.
112
+ if (srcType == resultType &&
113
+ llvm::count_if (srcType.getShape (), ShapedType::isDynamic) < 2 ) {
114
+ return reshapeSrcOp.getSrc ();
115
+ }
116
+
117
+ // Fold producer-consumer reshape ops when they are perfect inverses of each
118
+ // other:
119
+ // 1) Reassociation indices are equivalent.
120
+ // 2) Boundary types are equivalent.
121
+ // 3) No reassociations have more than 1 dynamic dimension, and reassociated
122
+ // shapes are equal for each reassociation.
123
+ auto reassociations = reshapeOp.getReassociationIndices ();
124
+ auto inverseReassociations = reshapeSrcOp.getReassociationIndices ();
125
+ if (reassociations != inverseReassociations)
126
+ return nullptr ;
127
+ ArrayRef<int64_t > expandedSrcShape = srcType.getShape ();
128
+ ArrayRef<int64_t > expandedResultShape = resultType.getShape ();
129
+ if (llvm::none_of (reassociations, [&](auto reInd) {
130
+ auto srcSlice = expandedSrcShape.slice (reInd.front (), reInd.size ());
131
+ auto resSlice = expandedResultShape.slice (reInd.front (), reInd.size ());
132
+ return srcSlice == resSlice &&
133
+ llvm::count_if (srcSlice, ShapedType::isDynamic) > 1 ;
134
+ })) {
135
+ return reshapeSrcOp.getSrc ();
136
+ }
103
137
return nullptr ;
104
138
}
105
139
@@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
360
394
resultShape.slice (resultIndices.front (), resultIndices.size ());
361
395
362
396
if (srcSubShape.size () == resultSubShape.size ()) {
363
- if (srcSubShape == resultSubShape)
397
+ if (srcSubShape == resultSubShape &&
398
+ llvm::count_if (srcSubShape, ShapedType::isDynamic) < 2 ) {
364
399
composedReassociation.push_back (srcIndices);
365
- else
400
+ } else {
366
401
return std::nullopt;
402
+ }
367
403
}
368
404
369
405
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
0 commit comments