@@ -1613,11 +1613,12 @@ VectorType dropNonScalableUnitDimType(VectorType inVecTy) {
1613
1613
SmallVector<int64_t > newShape;
1614
1614
SmallVector<bool > newScalableDims;
1615
1615
for (auto [dim, isScalable] :
1616
- llvm::zip (inVecShape, inVecTy.getScalableDims ())) {
1617
- if (dim != 1 || isScalable) {
1618
- newShape.push_back (dim);
1619
- newScalableDims.push_back (isScalable);
1620
- }
1616
+ llvm::zip_equal (inVecShape, inVecTy.getScalableDims ())) {
1617
+ if (dim == 1 && !isScalable)
1618
+ continue ;
1619
+
1620
+ newShape.push_back (dim);
1621
+ newScalableDims.push_back (isScalable);
1621
1622
}
1622
1623
1623
1624
return VectorType::get (newShape, inVecTy.getElementType (), newScalableDims);
@@ -1676,9 +1677,9 @@ struct DropUnitDimFromElementwiseOps final
1676
1677
for (auto operand : op->getOperands ()) {
1677
1678
auto opVectorType = cast<VectorType>(operand.getType ());
1678
1679
auto newVType = dropNonScalableUnitDimType (opVectorType);
1679
- if (newVType == opVectorType) {
1680
+ if (newVType == opVectorType)
1680
1681
return rewriter.notifyMatchFailure (op, " No unit dimension to remove." );
1681
- }
1682
+
1682
1683
auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
1683
1684
newOperands.push_back (opSC);
1684
1685
}
0 commit comments