Skip to content

Commit 6df6377

Browse files
committed
[MLIR][Vector] Enable DropUnitDimFromBroadcastOp
1 parent 8b22bb8 commit 6df6377

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
16951695
}
16961696
};
16971697

1698+
/// Drops unit non scalable dimensions inside a broadcastOp which are shared
1699+
/// among source and result with shape_casts.
1700+
/// The newly inserted shape_cast Ops fold (before Op) and then
1701+
/// restore the unit dim after Op. Source type is required to be a vector.
1702+
///
1703+
/// Ex:
1704+
/// ```
1705+
/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
1706+
/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
1707+
/// ```
1708+
///
1709+
/// Gets converted to:
1710+
///
1711+
/// ```
1712+
/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
1713+
/// %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
1714+
/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
1715+
/// vector<1x3x1x4xf32>
1716+
/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
1717+
/// vector<1x3x4xf32>
1718+
/// ```
1719+
/// %cast_new and %cast can be folded away.
1720+
struct DropUnitDimFromBroadcastOp final
1721+
: public OpRewritePattern<vector::BroadcastOp> {
1722+
using OpRewritePattern::OpRewritePattern;
1723+
1724+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
1725+
PatternRewriter &rewriter) const override {
1726+
auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
1727+
if (!srcVT)
1728+
return failure();
1729+
auto resVT = broadcastOp.getResultVectorType();
1730+
VectorType newSrcVT = srcVT;
1731+
VectorType newResVT = resVT;
1732+
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
1733+
// Reversing allows us to remove dims from the back without keeping track of
1734+
// removed dimensions.
1735+
for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
1736+
if (dim.value() == 1 &&
1737+
!srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
1738+
!broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
1739+
newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
1740+
dim.index() - 1);
1741+
newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
1742+
dim.index() - 1);
1743+
}
1744+
}
1745+
1746+
if (newSrcVT == srcVT)
1747+
return failure();
1748+
auto loc = broadcastOp->getLoc();
1749+
auto newSource = rewriter.create<vector::ShapeCastOp>(
1750+
loc, newSrcVT, broadcastOp.getSource());
1751+
auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
1752+
rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
1753+
newOp.getResult());
1754+
return success();
1755+
}
1756+
};
1757+
16981758
/// Pattern to eliminate redundant zero-constants added to reduction operands.
16991759
/// It's enough for there to be one initial zero value, so we can eliminate the
17001760
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
18191879

18201880
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
18211881
RewritePatternSet &patterns, PatternBenefit benefit) {
1822-
patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1823-
patterns.getContext(), benefit);
1882+
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
1883+
ShapeCastOpFolder>(patterns.getContext(), benefit);
18241884
}
18251885

18261886
void mlir::vector::populateBubbleVectorBitCastOpPatterns(

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,61 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
460460
// CHECK-128B-NOT: memref.collapse_shape
461461

462462

463+
// -----
464+
465+
func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
466+
%bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
467+
return %bc : vector<4x1x[1]x3x1xf128>
468+
}
469+
470+
// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
471+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
472+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
473+
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
474+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
475+
// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
476+
477+
// -----
478+
479+
func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
480+
%bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
481+
return %bc : vector<1x1xf32>
482+
}
483+
484+
// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
485+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
486+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
487+
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
488+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
489+
// CHECK: return %[[VAL_3]] : vector<1x1xf32>
490+
491+
// -----
492+
493+
// Generated unit dimensions through broadcasts are not dropped as we prefer to have a
494+
// single broadcast rather than a broadcast and a shape_cast.
495+
func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
496+
%bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
497+
return %bc : vector<3x1x4xf32>
498+
}
499+
500+
// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
501+
// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
502+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
503+
// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>
504+
505+
// -----
506+
507+
// A unit dimension that is broadcast to a larger size must not be dropped, as doing so would cause a type mismatch.
508+
func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
509+
%bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
510+
return %bc : vector<2x3x4xf32>
511+
}
512+
// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
513+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
514+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
515+
// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>
516+
517+
463518
// -----
464519

465520
func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,

0 commit comments

Comments
 (0)