Skip to content

Commit d94eb3d

Browse files
committed
[MLIR][Vector] Enable DropUnitDimFromBroadcastOp
1 parent 2c06fb8 commit d94eb3d

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
17031703
}
17041704
};
17051705

1706+
/// Drops unit non scalable dimensions inside a broadcastOp which are shared
1707+
/// among source and result with shape_casts.
1708+
/// The newly inserted shape_cast Ops fold (before Op) and then
1709+
/// restore the unit dim after Op. Source type is required to be a vector.
1710+
///
1711+
/// Ex:
1712+
/// ```
1713+
/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
1714+
/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
1715+
/// ```
1716+
///
1717+
/// Gets converted to:
1718+
///
1719+
/// ```
1720+
/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
1721+
/// %bc = vector.broadcast %arg : vector<4xf32> to vector<1x3x4xf32>
1722+
/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
1723+
/// vector<1x3x1x4xf32>
1724+
/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
1725+
/// vector<1x3x4xf32>
1726+
/// ```
1727+
/// %cast_new and %cast can be folded away.
1728+
struct DropUnitDimFromBroadcastOp final
1729+
: public OpRewritePattern<vector::BroadcastOp> {
1730+
using OpRewritePattern::OpRewritePattern;
1731+
1732+
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
1733+
PatternRewriter &rewriter) const override {
1734+
auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
1735+
if (!srcVT)
1736+
return failure();
1737+
auto resVT = broadcastOp.getResultVectorType();
1738+
VectorType newSrcVT = srcVT;
1739+
VectorType newResVT = resVT;
1740+
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
1741+
// Reversing allows us to remove dims from the back without keeping track of
1742+
// removed dimensions.
1743+
for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
1744+
if (dim.value() == 1 &&
1745+
!srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
1746+
!broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
1747+
newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
1748+
dim.index() - 1);
1749+
newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
1750+
dim.index() - 1);
1751+
}
1752+
}
1753+
1754+
if (newSrcVT == srcVT)
1755+
return failure();
1756+
auto loc = broadcastOp->getLoc();
1757+
auto newSource = rewriter.create<vector::ShapeCastOp>(
1758+
loc, newSrcVT, broadcastOp.getSource());
1759+
auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
1760+
rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
1761+
newOp.getResult());
1762+
return success();
1763+
}
1764+
};
1765+
17061766
/// Pattern to eliminate redundant zero-constants added to reduction operands.
17071767
/// It's enough for there to be one initial zero value, so we can eliminate the
17081768
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1827,8 +1887,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
18271887

18281888
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
18291889
RewritePatternSet &patterns, PatternBenefit benefit) {
1830-
patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
1831-
patterns.getContext(), benefit);
1890+
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
1891+
ShapeCastOpFolder>(patterns.getContext(), benefit);
18321892
}
18331893

18341894
void mlir::vector::populateBubbleVectorBitCastOpPatterns(

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,60 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
535535

536536
// -----
537537

538+
func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
539+
%bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
540+
return %bc : vector<4x1x[1]x3x1xf128>
541+
}
542+
543+
// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
544+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
545+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
546+
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
547+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
548+
// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
549+
550+
// -----
551+
552+
func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
553+
%bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
554+
return %bc : vector<1x1xf32>
555+
}
556+
557+
// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
558+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
559+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
560+
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
561+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
562+
// CHECK: return %[[VAL_3]] : vector<1x1xf32>
563+
564+
// -----
565+
566+
// Generated unit dimensions through broadcasts are not dropped as we prefer to have a
567+
// single broadcast rather than a broadcast and a shape_cast.
568+
func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
569+
%bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
570+
return %bc : vector<3x1x4xf32>
571+
}
572+
573+
// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
574+
// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
575+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
576+
// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>
577+
578+
// -----
579+
580+
// A unit dimension that is broadcast (stretched) to a non-unit size cannot be dropped, as that would cause a type mismatch.
581+
func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
582+
%bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
583+
return %bc : vector<2x3x4xf32>
584+
}
585+
// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
586+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
587+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
588+
// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>
589+
590+
// -----
591+
538592
func.func @negative_out_of_bound_transfer_read(
539593
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
540594
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)