@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
1695
1695
}
1696
1696
};
1697
1697
1698
+ // / Drops unit non scalable dimensions inside a broadcastOp which are shared
1699
+ // / among source and result with shape_casts.
1700
+ // / The newly inserted shape_cast Ops fold (before Op) and then
1701
+ // / restore the unit dim after Op. Source type is required to be a vector.
1702
+ // /
1703
+ // / Ex:
1704
+ // / ```
1705
+ // / %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
1706
+ // / %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
1707
+ // / ```
1708
+ // /
1709
+ // / Gets converted to:
1710
+ // /
1711
+ // / ```
1712
+ // / %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
1713
+ // / %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
1714
+ // / %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
1715
+ // / vector<1x3x1x4xf32>
1716
+ // / %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
1717
+ // / vector<1x3x4xf32>
1718
+ // / ```
1719
+ // / %cast_new and %cast can be folded away.
1720
+ struct DropUnitDimFromBroadcastOp final
1721
+ : public OpRewritePattern<vector::BroadcastOp> {
1722
+ using OpRewritePattern::OpRewritePattern;
1723
+
1724
+ LogicalResult matchAndRewrite (vector::BroadcastOp broadcastOp,
1725
+ PatternRewriter &rewriter) const override {
1726
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1727
+ if (!srcVT)
1728
+ return failure ();
1729
+ auto resVT = broadcastOp.getResultVectorType ();
1730
+ VectorType newSrcVT = srcVT;
1731
+ VectorType newResVT = resVT;
1732
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims ();
1733
+ // Reversing allows us to remove dims from the back without keeping track of
1734
+ // removed dimensions.
1735
+ for (const auto &dim : llvm::enumerate (llvm::reverse (srcVT.getShape ()))) {
1736
+ if (dim.value () == 1 &&
1737
+ !srcVT.getScalableDims ()[srcVT.getRank () - dim.index () - 1 ] &&
1738
+ !broadcastedUnitDims.contains (srcVT.getRank () - dim.index () - 1 )) {
1739
+ newSrcVT = VectorType::Builder (newSrcVT).dropDim (srcVT.getRank () -
1740
+ dim.index () - 1 );
1741
+ newResVT = VectorType::Builder (newResVT).dropDim (resVT.getRank () -
1742
+ dim.index () - 1 );
1743
+ }
1744
+ }
1745
+
1746
+ if (newSrcVT == srcVT)
1747
+ return failure ();
1748
+ auto loc = broadcastOp->getLoc ();
1749
+ auto newSource = rewriter.create <vector::ShapeCastOp>(
1750
+ loc, newSrcVT, broadcastOp.getSource ());
1751
+ auto newOp = rewriter.create <vector::BroadcastOp>(loc, newResVT, newSource);
1752
+ rewriter.replaceOpWithNewOp <ShapeCastOp>(broadcastOp, resVT,
1753
+ newOp.getResult ());
1754
+ return success ();
1755
+ }
1756
+ };
1757
+
1698
1758
// / Pattern to eliminate redundant zero-constants added to reduction operands.
1699
1759
// / It's enough for there to be one initial zero value, so we can eliminate the
1700
1760
// / extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1819
1879
1820
1880
void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
1821
1881
RewritePatternSet &patterns, PatternBenefit benefit) {
1822
- patterns.add <DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1823
- patterns.getContext (), benefit);
1882
+ patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
1883
+ ShapeCastOpFolder>( patterns.getContext (), benefit);
1824
1884
}
1825
1885
1826
1886
void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments