@@ -1612,7 +1612,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1612
1612
}
1613
1613
};
1614
1614
1615
- // / For vectors with either leading or trailing unit dim, replaces:
1615
+ // Scalable unit dimensions are not supported. Folding such dimensions would
1616
+ // require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1617
+ // vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1618
+ // future.
1619
+ static VectorType dropNonScalableUnitDimFromType (VectorType inVecTy) {
1620
+ auto inVecShape = inVecTy.getShape ();
1621
+ SmallVector<int64_t > newShape;
1622
+ SmallVector<bool > newScalableDims;
1623
+ for (auto [dim, isScalable] :
1624
+ llvm::zip_equal (inVecShape, inVecTy.getScalableDims ())) {
1625
+ if (dim == 1 && !isScalable)
1626
+ continue ;
1627
+
1628
+ newShape.push_back (dim);
1629
+ newScalableDims.push_back (isScalable);
1630
+ }
1631
+
1632
+ return VectorType::get (newShape, inVecTy.getElementType (), newScalableDims);
1633
+ }
1634
+
1635
+ // / For vectors with at least one unit dim, replaces:
1616
1636
// / elementwise(a, b)
1617
1637
// / with:
1618
1638
// / sc_a = shape_cast(a)
@@ -1624,20 +1644,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1624
1644
// / required to be rank > 1.
1625
1645
// /
1626
1646
// / Ex:
1627
- // / ```
1628
1647
// / %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1629
1648
// / %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1630
- // / ```
1631
1649
// /
1632
1650
// / gets converted to:
1633
1651
// /
1634
- // / ```
1635
1652
// / %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1636
1653
// / %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1637
1654
// / %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1638
1655
// / %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1639
1656
// / %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1640
- // / ```
1641
1657
// /
1642
1658
// / Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1643
1659
// / `%cast`.
@@ -1657,42 +1673,29 @@ struct DropUnitDimFromElementwiseOps final
1657
1673
// guaranteed to have identical shapes (with some exceptions such as
1658
1674
// `arith.select`) and it suffices to only check one of them.
1659
1675
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand (0 ).getType ());
1660
- if (!sourceVectorType)
1661
- return failure ();
1662
- if (sourceVectorType.getRank () < 2 )
1663
- return failure ();
1664
-
1665
- bool hasTrailingDimUnitFixed =
1666
- ((sourceVectorType.getShape ().back () == 1 ) &&
1667
- (!sourceVectorType.getScalableDims ().back ()));
1668
- bool hasLeadingDimUnitFixed =
1669
- ((sourceVectorType.getShape ().front () == 1 ) &&
1670
- (!sourceVectorType.getScalableDims ().front ()));
1671
- if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1676
+ if (!sourceVectorType || sourceVectorType.getRank () < 2 )
1672
1677
return failure ();
1673
1678
1674
- // Drop leading/trailing unit dim by applying vector.shape_cast to all
1675
- // operands
1676
- int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank () - 1 ;
1677
-
1678
1679
SmallVector<Value> newOperands;
1679
1680
auto loc = op->getLoc ();
1680
1681
for (auto operand : op->getOperands ()) {
1681
1682
auto opVectorType = cast<VectorType>(operand.getType ());
1682
- VectorType newVType = VectorType::Builder (opVectorType).dropDim (dim);
1683
+ auto newVType = dropNonScalableUnitDimFromType (opVectorType);
1684
+ if (newVType == opVectorType)
1685
+ return rewriter.notifyMatchFailure (op, " No unit dimension to remove." );
1686
+
1683
1687
auto opSC = rewriter.create <vector::ShapeCastOp>(loc, newVType, operand);
1684
1688
newOperands.push_back (opSC);
1685
1689
}
1686
1690
1687
1691
VectorType newResultVectorType =
1688
- VectorType::Builder (resultVectorType). dropDim (dim );
1689
- // Create an updated elementwise Op without leading/trailing unit dim
1692
+ dropNonScalableUnitDimFromType (resultVectorType);
1693
+ // Create an updated elementwise Op without unit dim.
1690
1694
Operation *elementwiseOp =
1691
1695
rewriter.create (loc, op->getName ().getIdentifier (), newOperands,
1692
1696
newResultVectorType, op->getAttrs ());
1693
1697
1694
- // Restore the leading/trailing unit dim by applying vector.shape_cast
1695
- // to the result
1698
+ // Restore the unit dim by applying vector.shape_cast to the result.
1696
1699
rewriter.replaceOpWithNewOp <ShapeCastOp>(op, resultVectorType,
1697
1700
elementwiseOp->getResult (0 ));
1698
1701
0 commit comments