@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
1703
1703
}
1704
1704
};
1705
1705
1706
+ // / Drops unit non scalable dimensions inside a broadcastOp which are shared
1707
+ // / among source and result with shape_casts.
1708
+ // / The newly inserted shape_cast Ops fold (before Op) and then
1709
+ // / restore the unit dim after Op. Source type is required to be a vector.
1710
+ // /
1711
+ // / Ex:
1712
+ // / ```
1713
+ // / %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
1714
+ // / %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
1715
+ // / ```
1716
+ // /
1717
+ // / Gets converted to:
1718
+ // /
1719
+ // / ```
1720
+ // / %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
1721
+ // / %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
1722
+ // / %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
1723
+ // / vector<1x3x1x4xf32>
1724
+ // / %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
1725
+ // / vector<1x3x4xf32>
1726
+ // / ```
1727
+ // / %cast_new and %cast can be folded away.
1728
+ struct DropUnitDimFromBroadcastOp final
1729
+ : public OpRewritePattern<vector::BroadcastOp> {
1730
+ using OpRewritePattern::OpRewritePattern;
1731
+
1732
+ LogicalResult matchAndRewrite (vector::BroadcastOp broadcastOp,
1733
+ PatternRewriter &rewriter) const override {
1734
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType ());
1735
+ if (!srcVT)
1736
+ return failure ();
1737
+ auto resVT = broadcastOp.getResultVectorType ();
1738
+ VectorType newSrcVT = srcVT;
1739
+ VectorType newResVT = resVT;
1740
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims ();
1741
+ // Reversing allows us to remove dims from the back without keeping track of
1742
+ // removed dimensions.
1743
+ for (const auto &dim : llvm::enumerate (llvm::reverse (srcVT.getShape ()))) {
1744
+ if (dim.value () == 1 &&
1745
+ !srcVT.getScalableDims ()[srcVT.getRank () - dim.index () - 1 ] &&
1746
+ !broadcastedUnitDims.contains (srcVT.getRank () - dim.index () - 1 )) {
1747
+ newSrcVT = VectorType::Builder (newSrcVT).dropDim (srcVT.getRank () -
1748
+ dim.index () - 1 );
1749
+ newResVT = VectorType::Builder (newResVT).dropDim (resVT.getRank () -
1750
+ dim.index () - 1 );
1751
+ }
1752
+ }
1753
+
1754
+ if (newSrcVT == srcVT)
1755
+ return failure ();
1756
+ auto loc = broadcastOp->getLoc ();
1757
+ auto newSource = rewriter.create <vector::ShapeCastOp>(
1758
+ loc, newSrcVT, broadcastOp.getSource ());
1759
+ auto newOp = rewriter.create <vector::BroadcastOp>(loc, newResVT, newSource);
1760
+ rewriter.replaceOpWithNewOp <ShapeCastOp>(broadcastOp, resVT,
1761
+ newOp.getResult ());
1762
+ return success ();
1763
+ }
1764
+ };
1765
+
1706
1766
// / Pattern to eliminate redundant zero-constants added to reduction operands.
1707
1767
// / It's enough for there to be one initial zero value, so we can eliminate the
1708
1768
// / extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1827,8 +1887,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
1827
1887
1828
1888
void mlir::vector::populateDropUnitDimWithShapeCastPatterns (
1829
1889
RewritePatternSet &patterns, PatternBenefit benefit) {
1830
- patterns.add <DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1831
- patterns.getContext (), benefit);
1890
+ patterns.add <DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
1891
+ ShapeCastOpFolder>( patterns.getContext (), benefit);
1832
1892
}
1833
1893
1834
1894
void mlir::vector::populateBubbleVectorBitCastOpPatterns (
0 commit comments