Skip to content

Commit 2dc8fea

Browse files
committed
address comments
1 parent 0536c7e commit 2dc8fea

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
104104
if (srcType != resultType)
105105
return nullptr;
106106

107-
if (srcType == resultType &&
108-
llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
107+
if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109108
return reshapeSrcOp.getSrc();
110109
}
111110

@@ -116,20 +115,18 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
116115
// 3) No reassociations have more than 1 dynamic dimension, and reassociated
117116
// shapes are equal for each reassociation.
118117
auto reassociations = reshapeOp.getReassociationIndices();
119-
auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
120-
if (reassociations != inverseReassociations)
118+
if (reassociations != reshapeSrcOp.getReassociationIndices())
121119
return nullptr;
122120
// If the reshapes are expanding and then collapsing, the ops can be folded
123121
// despite multiple dynamic dimensions.
124122
if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
125123
return reshapeSrcOp.getSrc();
126124
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
127125
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
128-
if (llvm::none_of(reassociations, [&](auto reInd) {
129-
auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
130-
auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
131-
return srcSlice == resSlice &&
132-
llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
126+
if (llvm::all_of(reassociations, [&](auto reInd) {
127+
ArrayRef<int64_t> srcSlice =
128+
expandedSrcShape.slice(reInd.front(), reInd.size());
129+
return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
133130
})) {
134131
return reshapeSrcOp.getSrc();
135132
}

0 commit comments

Comments
 (0)