Skip to content

Commit 7b09906

Browse files
committed
Fixup: simple exit and use zip_equal
1 parent 32e81c5 commit 7b09906

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,11 +1613,12 @@ VectorType dropNonScalableUnitDimType(VectorType inVecTy) {
16131613
SmallVector<int64_t> newShape;
16141614
SmallVector<bool> newScalableDims;
16151615
for (auto [dim, isScalable] :
1616-
llvm::zip(inVecShape, inVecTy.getScalableDims())) {
1617-
if (dim != 1 || isScalable) {
1618-
newShape.push_back(dim);
1619-
newScalableDims.push_back(isScalable);
1620-
}
1616+
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1617+
if (dim == 1 && !isScalable)
1618+
continue;
1619+
1620+
newShape.push_back(dim);
1621+
newScalableDims.push_back(isScalable);
16211622
}
16221623

16231624
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
@@ -1676,9 +1677,9 @@ struct DropUnitDimFromElementwiseOps final
16761677
for (auto operand : op->getOperands()) {
16771678
auto opVectorType = cast<VectorType>(operand.getType());
16781679
auto newVType = dropNonScalableUnitDimType(opVectorType);
1679-
if (newVType == opVectorType) {
1680+
if (newVType == opVectorType)
16801681
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1681-
}
1682+
16821683
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16831684
newOperands.push_back(opSC);
16841685
}

0 commit comments

Comments
 (0)