@@ -1607,20 +1607,20 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1607
1607
}
1608
1608
};
1609
1609
1610
- FailureOr<VectorType> dropNonScalableUnitDimType (VectorType inVecTy) {
1611
- int numUnitDimsDropped = 0 ;
1612
- auto inVecShape = inVecTy.getShape ();
1610
+ VectorType dropNonScalableUnitDimType (VectorType inVecTy) {
1613
1611
auto newVecBuilder = VectorType::Builder (inVecTy);
1614
- for (unsigned i = 0 ; i < inVecShape.size (); i++) {
1615
- if (inVecShape[i] == 1 && !inVecTy.getScalableDims ()[i]) {
1616
- newVecBuilder.dropDim (i - numUnitDimsDropped);
1617
- numUnitDimsDropped++;
1612
+ auto inVecShape = inVecTy.getShape ();
1613
+ SmallVector<int64_t > newShape;
1614
+ SmallVector<bool > newScalableDims;
1615
+ for (auto [dim, isScalable] :
1616
+ llvm::zip (inVecShape, inVecTy.getScalableDims ())) {
1617
+ if (dim != 1 || isScalable) {
1618
+ newShape.push_back (dim);
1619
+ newScalableDims.push_back (isScalable);
1618
1620
}
1619
1621
}
1620
1622
1621
- if (numUnitDimsDropped == 0 )
1622
- return failure ();
1623
- return VectorType (newVecBuilder);
1623
+ return VectorType::get (newShape, inVecTy.getElementType (), newScalableDims);
1624
1624
}
1625
1625
1626
1626
// / For vectors with at least an unit dim, replaces:
@@ -1676,16 +1676,15 @@ struct DropUnitDimFromElementwiseOps final
1676
1676
for (auto operand : op->getOperands ()) {
1677
1677
auto opVectorType = cast<VectorType>(operand.getType ());
1678
1678
auto newVType = dropNonScalableUnitDimType (opVectorType);
1679
- if (failed ( newVType) ) {
1680
- return failure ( );
1679
+ if (newVType == opVectorType ) {
1680
+ return rewriter. notifyMatchFailure (op, " No unit dimension to remove. " );
1681
1681
}
1682
- auto opSC =
1683
- rewriter.create <vector::ShapeCastOp>(loc, newVType.value (), operand);
1682
+ auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
1684
1683
newOperands.push_back (opSC);
1685
1684
}
1686
1685
1687
1686
VectorType newResultVectorType =
1688
- dropNonScalableUnitDimType (resultVectorType). value () ;
1687
+ dropNonScalableUnitDimType (resultVectorType);
1689
1688
// Create an updated elementwise Op without unit dim
1690
1689
Operation *elementwiseOp =
1691
1690
rewriter.create (loc, op->getName ().getIdentifier (), newOperands,
0 commit comments