
[MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern #92938

Open

wants to merge 3 commits into main
65 changes: 63 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1703,6 +1703,67 @@ struct DropUnitDimFromElementwiseOps final
}
};

/// Drops non-scalable unit dimensions that are shared between the source and
/// the result of a vector.broadcast by wrapping the broadcast in shape_casts.
/// The shape_cast inserted before the broadcast is expected to fold away, and
/// the shape_cast inserted after it restores the unit dims. The source type is
/// required to be a vector.
///
/// Ex:
/// ```
///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
/// ```
///
/// Gets converted to:
///
/// ```
///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
///  %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
///    vector<1x3x1x4xf32>
///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
///    vector<1x3x4xf32>
/// ```
/// %cast_new and %cast can be folded away.
struct DropUnitDimFromBroadcastOp final
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
    if (!srcVecTy)
      return failure();
    auto resVecTy = broadcastOp.getResultVectorType();
    auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
    auto resVecTyBuilder = VectorType::Builder(resVecTy);
Comment on lines +1738 to +1739

Member
Builders may be less efficient than just appending the dims not dropped to a new vector (but this is probably not much of a concern given the number of dims is normally < 5-ish).

Member
(no need to change this -- just a note)

Contributor Author

I agree. I tried implementing it by appending the shapes, but considering that the scalableDims also have to be rebuilt for both the new source and result types, I found it generated a lot of code that is otherwise hidden by dropDim.
Also, I was able to generate the base of the new resultShape by creating a subvector of it with:

SmallVector<int64_t> newResShape =
        llvm::to_vector(resVecTy.getShape().drop_back(srcVecTy.getRank()));

but for the scalable dims I get errors like the one below, and I don't think I should be changing the behaviour of SmallVector. I suspect it comes from the way the scalable dims are defined.

llvm/include/llvm/ADT/SmallVector.h:1317:11: error: type 'decltype(__cont.begin())' (aka 'const bool *') cannot be narrowed to 'bool' in initializer list [-Wc++11-narrowing]

To fix it, I had to create an ugly vector with explicit casts like:

    SmallVector<bool> newResScalableDims = {
        static_cast<bool>(resVecTy.getScalableDims().begin()),
        static_cast<bool>(resVecTy.getScalableDims().drop_back(srcVecTy.getRank()).end())};

If you want, I can push my solution on top, and we can revert it if we prefer the current version.

Member
Those static_casts look very suspect 😅. It looks like that's just going to make a SmallVector of two (likely true) bools. I think keeping it as-is is fine as it's simpler, and likely not really a performance concern (vector types are normally small).
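
For reference, a rough sketch of the append-based alternative discussed above (untested and not part of this patch; it reuses srcVecTy, resVecTy and broadcastedUnitDims from the pattern). The narrowing error quoted earlier appears to come from brace-initializing a SmallVector<bool> from pointer iterators, so this version pushes each scalable flag individually instead of copying the ArrayRef<bool> ranges:

// Untested sketch; names follow the pattern above.
int64_t rankDiff = resVecTy.getRank() - srcVecTy.getRank();
ArrayRef<bool> srcScalableDims = srcVecTy.getScalableDims();
ArrayRef<bool> resScalableDims = resVecTy.getScalableDims();

SmallVector<int64_t> newSrcShape, newResShape;
SmallVector<bool> newSrcScalableDims, newResScalableDims;

// Leading result dims that have no source counterpart are kept unchanged.
for (int64_t i = 0; i < rankDiff; ++i) {
  newResShape.push_back(resVecTy.getDimSize(i));
  newResScalableDims.push_back(resScalableDims[i]);
}

for (auto [idx, dim] : llvm::enumerate(srcVecTy.getShape())) {
  // Drop unit dims that are neither scalable nor stretched by the broadcast.
  if (dim == 1 && !srcScalableDims[idx] && !broadcastedUnitDims.contains(idx))
    continue;
  newSrcShape.push_back(dim);
  newSrcScalableDims.push_back(srcScalableDims[idx]);
  newResShape.push_back(resVecTy.getDimSize(idx + rankDiff));
  newResScalableDims.push_back(resScalableDims[idx + rankDiff]);
}

auto newSrcVecTy = VectorType::get(newSrcShape, srcVecTy.getElementType(),
                                   newSrcScalableDims);
auto newResVecTy = VectorType::get(newResShape, resVecTy.getElementType(),
                                   newResScalableDims);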

Member
@MacDue Jun 20, 2024

Btw, I've been trying to make writing code like this easier for scalable dims (for a while now 😅). With my current attempt #96236, I think you'd be able to rewrite this as (untested!!!):

auto srcDims = VectorDimList::from(srcVecTy);
auto resDims = VectorDimList::from(resVecTy);
auto rankDiff = resDims.size() - srcDims.size();

SmallVector<VectorDim> newSrcDims;
SmallVector<VectorDim> newResDims(resDims.takeFront(rankDiff));

auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
for (auto [idx, dim] : llvm::enumerate(srcDims)) {
  if (dim != VectorDim::getFixed(1) || broadcastedUnitDims.contains(idx)) {
    newSrcDims.push_back(dim);
    newResDims.push_back(resDims[idx + rankDiff]);
  }
}

auto newSourceType = ScalableVectorType::get(newSrcDims, srcVecTy.getElementType());
auto newResultType = ScalableVectorType::get(newResDims, srcVecTy.getElementType());

Please take a look at the PR if you think it'd be useful :)

Contributor Author

That indeed looks way simpler; fortunate that you submitted that PR. I probably won't be able to review it right now, as I need to sort some things out today, but I hope to take a look. (I'll be off for a while, so do not wait for my review.)

    auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
    // Reversing allows us to remove dims from the back without keeping track
    // of removed dimensions.
    for (const auto [reversedIndex, dim] :
         llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
      unsigned srcDimIndex = srcVecTy.getRank() - reversedIndex - 1;
      unsigned resDimIndex = resVecTy.getRank() - reversedIndex - 1;
      if (dim == 1 && !srcVecTy.getScalableDims()[srcDimIndex] &&
          !broadcastedUnitDims.contains(srcDimIndex)) {
        srcVecTyBuilder.dropDim(srcDimIndex);
        resVecTyBuilder.dropDim(resDimIndex);
      }
    }

    if (VectorType(srcVecTyBuilder) == srcVecTy)
      return failure();
    auto loc = broadcastOp->getLoc();
    auto newSource = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
Comment on lines +1754 to +1758

Member
nit: avoid constructing the new vector type twice:

Suggested change
-    if (VectorType(srcVecTyBuilder) == srcVecTy)
-      return failure();
-    auto loc = broadcastOp->getLoc();
-    auto newSource = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
+    auto newSrcVecTy = VectorType(srcVecTyBuilder);
+    if (newSrcVecTy == srcVecTy)
+      return failure();
+    auto loc = broadcastOp->getLoc();
+    auto newSource = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcVecTy, broadcastOp.getSource());

    auto newOp = rewriter.create<vector::BroadcastOp>(
        loc, VectorType(resVecTyBuilder), newSource);
    rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVecTy,
                                             newOp.getResult());
    return success();
  }
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1827,8 +1888,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+               ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
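For context, a minimal sketch of how these patterns are typically driven from a pass (illustrative only, not part of this patch; the helper name is invented and the usual headers for func, vector, and the greedy pattern driver are assumed):

// Illustrative helper: greedily applies the unit-dim-dropping patterns,
// which now include DropUnitDimFromBroadcastOp, to a function.
static void dropVectorUnitDims(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  vector::populateDropUnitDimWithShapeCastPatterns(patterns, /*benefit=*/1);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    funcOp.emitError("failed to apply vector unit-dim dropping patterns");
}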
54 changes: 54 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -535,6 +535,60 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,

// -----

func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
  %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
  return %bc : vector<4x1x[1]x3x1xf128>
}

// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
Member
nit: Use better names than the generated VAL_* :)

// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>

// -----

func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
  %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
  return %bc : vector<1x1xf32>
}

// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[VAL_3]] : vector<1x1xf32>

// -----

// Unit dimensions generated through broadcasts are not dropped, as we prefer a
// single broadcast over a broadcast plus a shape_cast.
func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
  %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
  return %bc : vector<3x1x4xf32>
}

// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>

// -----

// A unit dimension that is stretched by the broadcast cannot be dropped, since
// its corresponding result dimension is not a unit dim.
func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
  %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
  return %bc : vector<2x3x4xf32>
}
// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>

// -----

func.func @negative_out_of_bound_transfer_read(
    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
  %c0 = arith.constant 0 : index