@@ -104,8 +104,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
104
104
if (srcType != resultType)
105
105
return nullptr ;
106
106
107
- if (srcType == resultType &&
108
- llvm::count_if (srcType.getShape (), ShapedType::isDynamic) < 2 ) {
107
+ if (llvm::count_if (srcType.getShape (), ShapedType::isDynamic) < 2 ) {
109
108
return reshapeSrcOp.getSrc ();
110
109
}
111
110
@@ -116,20 +115,18 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
116
115
// 3) No reassociations have more than 1 dynamic dimension, and reassociated
117
116
// shapes are equal for each reassociation.
118
117
auto reassociations = reshapeOp.getReassociationIndices ();
119
- auto inverseReassociations = reshapeSrcOp.getReassociationIndices ();
120
- if (reassociations != inverseReassociations)
118
+ if (reassociations != reshapeSrcOp.getReassociationIndices ())
121
119
return nullptr ;
122
120
// If the reshapes are expanding and then collapsing, the ops can be folded
123
121
// despite multiple dynamic dimensions.
124
122
if (srcType.getRank () < reshapeSrcOp.getResultType ().getRank ())
125
123
return reshapeSrcOp.getSrc ();
126
124
ArrayRef<int64_t > expandedSrcShape = srcType.getShape ();
127
125
ArrayRef<int64_t > expandedResultShape = resultType.getShape ();
128
- if (llvm::none_of (reassociations, [&](auto reInd) {
129
- auto srcSlice = expandedSrcShape.slice (reInd.front (), reInd.size ());
130
- auto resSlice = expandedResultShape.slice (reInd.front (), reInd.size ());
131
- return srcSlice == resSlice &&
132
- llvm::count_if (srcSlice, ShapedType::isDynamic) > 1 ;
126
+ if (llvm::all_of (reassociations, [&](auto reInd) {
127
+ ArrayRef<int64_t > srcSlice =
128
+ expandedSrcShape.slice (reInd.front (), reInd.size ());
129
+ return llvm::count_if (srcSlice, ShapedType::isDynamic) < 2 ;
133
130
})) {
134
131
return reshapeSrcOp.getSrc ();
135
132
}
0 commit comments